feat: consolidate reasoning model checks (#3862)
This commit is contained in:
@@ -182,7 +182,7 @@ class AnthropicClient(LLMClientBase):
|
||||
}
|
||||
|
||||
# Extended Thinking
|
||||
if llm_config.enable_reasoner:
|
||||
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||
data["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": llm_config.max_reasoning_tokens,
|
||||
@@ -200,7 +200,7 @@ class AnthropicClient(LLMClientBase):
|
||||
# Special case for summarization path
|
||||
tools_for_request = None
|
||||
tool_choice = None
|
||||
elif llm_config.enable_reasoner:
|
||||
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||
# NOTE: reasoning models currently do not allow for `any`
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools]
|
||||
@@ -296,6 +296,13 @@ class AnthropicClient(LLMClientBase):
|
||||
token_count -= 8
|
||||
return token_count
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Return True when the configured Anthropic model name has a reasoning-model prefix.

    The check is purely a prefix match on ``llm_config.model``; callers in this
    client combine it with ``llm_config.enable_reasoner`` before enabling
    extended thinking.

    Args:
        llm_config: Configuration whose ``model`` string is inspected.

    Returns:
        True if the model name starts with ``claude-3-7-sonnet``,
        ``claude-sonnet-4`` or ``claude-opus-4``; False otherwise.
    """
    # str.startswith accepts a tuple of prefixes, so one call replaces the `or` chain.
    reasoning_prefixes = ("claude-3-7-sonnet", "claude-sonnet-4", "claude-opus-4")
    return llm_config.model.startswith(reasoning_prefixes)
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
if isinstance(e, anthropic.APITimeoutError):
|
||||
|
||||
@@ -504,6 +504,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Return True when the configured model name has a Gemini 2.5 reasoning-model prefix.

    Args:
        llm_config: Configuration whose ``model`` string is inspected.

    Returns:
        True if the model name starts with ``gemini-2.5-flash`` or
        ``gemini-2.5-pro``; False otherwise.
    """
    # A tuple of prefixes lets a single startswith call cover both families.
    return llm_config.model.startswith(("gemini-2.5-flash", "gemini-2.5-pro"))
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
# Fallback to base implementation
|
||||
|
||||
@@ -174,6 +174,10 @@ class LLMClientBase:
|
||||
"""
|
||||
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
|
||||
|
||||
@abstractmethod
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Report whether the model named in ``llm_config`` is a reasoning model.

    This base version has no implementation; the provider-specific clients
    supply their own check.

    Args:
        llm_config: Configuration describing the model being used.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
"""
|
||||
|
||||
@@ -276,6 +276,9 @@ class OpenAIClient(LLMClientBase):
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Report whether the configured OpenAI model is a reasoning model.

    Delegates to ``is_openai_reasoning_model`` with the model name taken from
    ``llm_config.model``.

    Args:
        llm_config: Configuration whose ``model`` string is checked.

    Returns:
        Whatever ``is_openai_reasoning_model`` returns for that model name.
    """
    return is_openai_reasoning_model(llm_config.model)
|
||||
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
self,
|
||||
@@ -298,7 +301,7 @@ class OpenAIClient(LLMClientBase):
|
||||
)
|
||||
|
||||
# If we used a reasoning model, create a content part for the omitted reasoning
|
||||
if is_openai_reasoning_model(llm_config.model):
|
||||
if self.is_reasoning_model(llm_config):
|
||||
chat_completion_response.choices[0].message.omitted_reasoning_content = True
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
Reference in New Issue
Block a user