feat: consolidate reasoning model checks (#3862)

This commit is contained in:
cthomas
2025-08-11 16:55:45 -07:00
committed by GitHub
parent c8b370466e
commit 5cf807574f
10 changed files with 32 additions and 31 deletions

View File

@@ -182,7 +182,7 @@ class AnthropicClient(LLMClientBase):
}
# Extended Thinking
if llm_config.enable_reasoner:
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
data["thinking"] = {
"type": "enabled",
"budget_tokens": llm_config.max_reasoning_tokens,
@@ -200,7 +200,7 @@ class AnthropicClient(LLMClientBase):
# Special case for summarization path
tools_for_request = None
tool_choice = None
elif llm_config.enable_reasoner:
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
# NOTE: reasoning models currently do not allow for `any`
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [OpenAITool(function=f) for f in tools]
@@ -296,6 +296,13 @@ class AnthropicClient(LLMClientBase):
token_count -= 8
return token_count
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Return True if the configured Anthropic model supports extended thinking.

    Only the Claude 3.7 Sonnet and Claude 4 Sonnet/Opus model families accept
    the `thinking` request parameter, so the reasoning code paths are gated on
    this check.
    """
    # str.startswith accepts a tuple of prefixes — one call instead of an
    # `or`-chain of three separate startswith calls.
    return llm_config.model.startswith(
        (
            "claude-3-7-sonnet",
            "claude-sonnet-4",
            "claude-opus-4",
        )
    )
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
if isinstance(e, anthropic.APITimeoutError):

View File

@@ -504,6 +504,9 @@ class GoogleVertexClient(LLMClientBase):
return 1
return 0
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Return True if the configured Gemini model supports thinking budgets.

    Only the Gemini 2.5 Flash / Pro families expose reasoning controls, so
    reasoning-specific request fields are gated on this check.
    """
    # str.startswith accepts a tuple of prefixes — one call instead of an
    # `or`-chain of two startswith calls.
    return llm_config.model.startswith(("gemini-2.5-flash", "gemini-2.5-pro"))
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
# Fallback to base implementation

View File

@@ -174,6 +174,10 @@ class LLMClientBase:
"""
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
@abstractmethod
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Return True if the model in *llm_config* supports reasoning / extended thinking.

    Each provider subclass encodes its own model-name checks here so that
    callers can gate reasoning-specific request parameters on a single method.
    """
    raise NotImplementedError
@abstractmethod
def handle_llm_error(self, e: Exception) -> Exception:
"""

View File

@@ -276,6 +276,9 @@ class OpenAIClient(LLMClientBase):
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Return True if the configured model is an OpenAI reasoning model.

    Delegates to the shared `is_openai_reasoning_model` helper for the
    model-name check.
    """
    return is_openai_reasoning_model(llm_config.model)
@trace_method
def convert_response_to_chat_completion(
self,
@@ -298,7 +301,7 @@ class OpenAIClient(LLMClientBase):
)
# If we used a reasoning model, create a content part for the omitted reasoning
if is_openai_reasoning_model(llm_config.model):
if self.is_reasoning_model(llm_config):
chat_completion_response.choices[0].message.omitted_reasoning_content = True
return chat_completion_response