feat: consolidate reasoning model checks (#3862)

This commit is contained in:
cthomas
2025-08-11 16:55:45 -07:00
committed by GitHub
parent c8b370466e
commit 5cf807574f
10 changed files with 32 additions and 31 deletions

View File

@@ -96,7 +96,7 @@ all_configs = [
"openai-gpt-4o-mini.json", "openai-gpt-4o-mini.json",
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop # "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
"claude-3-5-sonnet.json", "claude-3-5-sonnet.json",
"claude-3-7-sonnet.json", "claude-4-sonnet-extended.json",
"claude-3-7-sonnet-extended.json", "claude-3-7-sonnet-extended.json",
"gemini-1.5-pro.json", "gemini-1.5-pro.json",
"gemini-2.5-flash-vertex.json", "gemini-2.5-flash-vertex.json",

View File

@@ -19,7 +19,7 @@ jobs:
- "openai-gpt-4o-mini.json" - "openai-gpt-4o-mini.json"
- "azure-gpt-4o-mini.json" - "azure-gpt-4o-mini.json"
- "claude-3-5-sonnet.json" - "claude-3-5-sonnet.json"
- "claude-3-7-sonnet.json" - "claude-4-sonnet-extended.json"
- "claude-3-7-sonnet-extended.json" - "claude-3-7-sonnet-extended.json"
- "gemini-pro.json" - "gemini-pro.json"
- "gemini-vertex.json" - "gemini-vertex.json"

View File

@@ -182,7 +182,7 @@ class AnthropicClient(LLMClientBase):
} }
# Extended Thinking # Extended Thinking
if llm_config.enable_reasoner: if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
data["thinking"] = { data["thinking"] = {
"type": "enabled", "type": "enabled",
"budget_tokens": llm_config.max_reasoning_tokens, "budget_tokens": llm_config.max_reasoning_tokens,
@@ -200,7 +200,7 @@ class AnthropicClient(LLMClientBase):
# Special case for summarization path # Special case for summarization path
tools_for_request = None tools_for_request = None
tool_choice = None tool_choice = None
elif llm_config.enable_reasoner: elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
# NOTE: reasoning models currently do not allow for `any` # NOTE: reasoning models currently do not allow for `any`
tool_choice = {"type": "auto", "disable_parallel_tool_use": True} tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [OpenAITool(function=f) for f in tools] tools_for_request = [OpenAITool(function=f) for f in tools]
@@ -296,6 +296,13 @@ class AnthropicClient(LLMClientBase):
token_count -= 8 token_count -= 8
return token_count return token_count
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return (
llm_config.model.startswith("claude-3-7-sonnet")
or llm_config.model.startswith("claude-sonnet-4")
or llm_config.model.startswith("claude-opus-4")
)
@trace_method @trace_method
def handle_llm_error(self, e: Exception) -> Exception: def handle_llm_error(self, e: Exception) -> Exception:
if isinstance(e, anthropic.APITimeoutError): if isinstance(e, anthropic.APITimeoutError):

View File

@@ -504,6 +504,9 @@ class GoogleVertexClient(LLMClientBase):
return 1 return 1
return 0 return 0
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return llm_config.model.startswith("gemini-2.5-flash") or llm_config.model.startswith("gemini-2.5-pro")
@trace_method @trace_method
def handle_llm_error(self, e: Exception) -> Exception: def handle_llm_error(self, e: Exception) -> Exception:
# Fallback to base implementation # Fallback to base implementation

View File

@@ -174,6 +174,10 @@ class LLMClientBase:
""" """
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}") raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
@abstractmethod
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
raise NotImplementedError
@abstractmethod @abstractmethod
def handle_llm_error(self, e: Exception) -> Exception: def handle_llm_error(self, e: Exception) -> Exception:
""" """

View File

@@ -276,6 +276,9 @@ class OpenAIClient(LLMClientBase):
response: ChatCompletion = await client.chat.completions.create(**request_data) response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump() return response.model_dump()
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return is_openai_reasoning_model(llm_config.model)
@trace_method @trace_method
def convert_response_to_chat_completion( def convert_response_to_chat_completion(
self, self,
@@ -298,7 +301,7 @@ class OpenAIClient(LLMClientBase):
) )
# If we used a reasoning model, create a content part for the ommitted reasoning # If we used a reasoning model, create a content part for the ommitted reasoning
if is_openai_reasoning_model(llm_config.model): if self.is_reasoning_model(llm_config):
chat_completion_response.choices[0].message.omitted_reasoning_content = True chat_completion_response.choices[0].message.omitted_reasoning_content = True
return chat_completion_response return chat_completion_response

View File

@@ -94,6 +94,9 @@ class LLMConfig(BaseModel):
""" """
model = values.get("model") model = values.get("model")
if model is None:
return values
# Define models where we want put_inner_thoughts_in_kwargs to be False # Define models where we want put_inner_thoughts_in_kwargs to be False
avoid_put_inner_thoughts_in_kwargs = ["gpt-4"] avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
@@ -107,25 +110,13 @@ class LLMConfig(BaseModel):
if is_openai_reasoning_model(model): if is_openai_reasoning_model(model):
values["put_inner_thoughts_in_kwargs"] = False values["put_inner_thoughts_in_kwargs"] = False
if values.get("enable_reasoner") and values.get("model_endpoint_type") == "anthropic": if values.get("model_endpoint_type") == "anthropic" and (
model.startswith("claude-3-7-sonnet") or model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")
):
values["put_inner_thoughts_in_kwargs"] = False values["put_inner_thoughts_in_kwargs"] = False
return values return values
@model_validator(mode="after")
def issue_warning_for_reasoning_constraints(self) -> "LLMConfig":
if self.enable_reasoner:
if self.max_reasoning_tokens is None:
logger.warning("max_reasoning_tokens must be set when enable_reasoner is True")
if self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens:
logger.warning("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
if self.put_inner_thoughts_in_kwargs:
logger.debug("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
elif self.max_reasoning_tokens and not self.enable_reasoner:
logger.warning("model will not use reasoning unless enable_reasoner is set to True")
return self
@classmethod @classmethod
def default_config(cls, model_name: str): def default_config(cls, model_name: str):
""" """

View File

@@ -1,8 +0,0 @@
{
"model": "claude-3-7-sonnet-20250219",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"context_window": 200000,
"put_inner_thoughts_in_kwargs": true
}

View File

@@ -4,5 +4,7 @@
"model_endpoint": "https://api.anthropic.com/v1", "model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null, "model_wrapper": null,
"context_window": 200000, "context_window": 200000,
"put_inner_thoughts_in_kwargs": true "put_inner_thoughts_in_kwargs": false,
"enable_reasoner": true,
"max_reasoning_tokens": 1024
} }

View File

@@ -118,9 +118,8 @@ all_configs = [
"openai-o3.json", "openai-o3.json",
"openai-o4-mini.json", "openai-o4-mini.json",
"azure-gpt-4o-mini.json", "azure-gpt-4o-mini.json",
"claude-4-sonnet.json", "claude-4-sonnet-extended.json",
"claude-3-5-sonnet.json", "claude-3-5-sonnet.json",
"claude-3-7-sonnet.json",
"claude-3-7-sonnet-extended.json", "claude-3-7-sonnet-extended.json",
"bedrock-claude-4-sonnet.json", "bedrock-claude-4-sonnet.json",
"gemini-1.5-pro.json", "gemini-1.5-pro.json",