feat: handle all cases for reasoning toggle (#3857)
This commit is contained in:
@@ -185,31 +185,50 @@ class LLMConfig(BaseModel):
|
||||
+ (f" [ip={self.model_endpoint}]" if self.model_endpoint else "")
|
||||
)
|
||||
|
||||
@classmethod
def is_openai_reasoning_model(cls, config: "LLMConfig") -> bool:
    """Return True when *config* targets an OpenAI o-series reasoning model (o1/o3/o4)."""
    if config.model_endpoint_type != "openai":
        return False
    # str.startswith accepts a tuple of prefixes, replacing the or-chain.
    return config.model.startswith(("o1", "o3", "o4"))
|
||||
|
||||
@classmethod
def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool:
    """Return True when *config* targets an Anthropic model with extended-thinking support."""
    if config.model_endpoint_type != "anthropic":
        return False
    # Single tuple-prefix check instead of three chained startswith calls.
    return config.model.startswith(
        ("claude-opus-4", "claude-sonnet-4", "claude-3-7-sonnet")
    )
|
||||
|
||||
@classmethod
def is_google_vertex_reasoning_model(cls, config: "LLMConfig") -> bool:
    """Return True when *config* targets a Gemini 2.5 thinking model on Google Vertex."""
    if config.model_endpoint_type != "google_vertex":
        return False
    return config.model.startswith(("gemini-2.5-flash", "gemini-2.5-pro"))
|
||||
|
||||
@classmethod
def apply_reasoning_setting_to_config(cls, config: "LLMConfig", reasoning: bool):
    """Mutate *config*'s reasoning-related fields to match the requested toggle.

    Args:
        config: The LLM configuration to update (mutated in place).
        reasoning: Whether reasoning should be enabled for this config.

    Returns:
        The same (mutated) config, for caller convenience.

    Raises:
        ValueError: If *reasoning* is False for a model that cannot run with
            reasoning disabled (OpenAI o-series, Gemini 2.5 Pro).
    """
    if not reasoning:
        # These models reason natively and cannot have it turned off.
        if cls.is_openai_reasoning_model(config) or config.model.startswith("gemini-2.5-pro"):
            raise ValueError("Reasoning cannot be disabled for OpenAI o-series and Gemini 2.5 Pro models")
        config.put_inner_thoughts_in_kwargs = False
        config.enable_reasoner = False
    else:
        config.enable_reasoner = True
        if cls.is_anthropic_reasoning_model(config):
            # Extended thinking emits native reasoning content, so inner
            # thoughts must not also be injected via tool-call kwargs.
            config.put_inner_thoughts_in_kwargs = False
            if config.max_reasoning_tokens == 0:
                config.max_reasoning_tokens = 1024
        elif cls.is_google_vertex_reasoning_model(config):
            # Handle as non-reasoner until we support summary
            config.put_inner_thoughts_in_kwargs = True
            if config.max_reasoning_tokens == 0:
                config.max_reasoning_tokens = 1024
        elif cls.is_openai_reasoning_model(config):
            config.put_inner_thoughts_in_kwargs = False
            if config.reasoning_effort is None:
                config.reasoning_effort = "medium"
        else:
            # Non-native reasoners simulate reasoning via inner thoughts
            # embedded in tool-call kwargs; enable_reasoner stays True.
            config.put_inner_thoughts_in_kwargs = True

    return config
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.providers import (
|
||||
AnthropicProvider,
|
||||
AzureProvider,
|
||||
@@ -314,3 +317,51 @@ async def test_provider_llm_models_consistency():
|
||||
assert model.handle.startswith(f"{provider.name}/")
|
||||
assert model.provider_name == provider.name
|
||||
assert model.context_window > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "handle, expected_enable_reasoner, expected_put_inner_thoughts_in_kwargs, expected_max_reasoning_tokens, expected_reasoning_effort, expected_exception",
    [
        ("openai/gpt-4o-mini", True, True, 0, None, None),
        ("openai/gpt-4o-mini", False, False, 0, None, None),
        ("openai/o3-mini", True, False, 0, "medium", None),
        ("openai/o3-mini", False, False, 0, None, ValueError),
        ("anthropic/claude-3.5-sonnet", True, True, 0, None, None),
        ("anthropic/claude-3.5-sonnet", False, False, 0, None, None),
        ("anthropic/claude-3-7-sonnet", True, False, 1024, None, None),
        ("anthropic/claude-3-7-sonnet", False, False, 0, None, None),
        ("anthropic/claude-sonnet-4", True, False, 1024, None, None),
        ("anthropic/claude-sonnet-4", False, False, 0, None, None),
        ("google_vertex/gemini-2.0-flash", True, True, 0, None, None),
        ("google_vertex/gemini-2.0-flash", False, False, 0, None, None),
        ("google_vertex/gemini-2.5-flash", True, True, 1024, None, None),
        ("google_vertex/gemini-2.5-flash", False, False, 0, None, None),
        ("google_vertex/gemini-2.5-pro", True, True, 1024, None, None),
        ("google_vertex/gemini-2.5-pro", False, False, 0, None, ValueError),
    ],
)
def test_reasoning_toggle_by_provider(
    handle: str,
    expected_enable_reasoner: bool,
    expected_put_inner_thoughts_in_kwargs: bool,
    expected_max_reasoning_tokens: int,
    expected_reasoning_effort: Optional[Literal["minimal", "low", "medium", "high"]],
    expected_exception: Optional[Exception],
):
    """Check apply_reasoning_setting_to_config for each provider/model handle.

    Builds a minimal LLMConfig from the handle, toggles reasoning, and
    verifies the resulting reasoning-related fields (or the expected error).
    """
    endpoint_type, model_name = handle.split("/")
    base_config = LLMConfig(
        model_endpoint_type=endpoint_type,
        model=model_name,
        handle=handle,
        context_window=1024,
    )

    if expected_exception:
        # Models that cannot run with the requested toggle must raise.
        with pytest.raises(expected_exception):
            LLMConfig.apply_reasoning_setting_to_config(base_config, reasoning=expected_enable_reasoner)
        return

    updated = LLMConfig.apply_reasoning_setting_to_config(base_config, reasoning=expected_enable_reasoner)

    assert updated.enable_reasoner == expected_enable_reasoner
    assert updated.put_inner_thoughts_in_kwargs == expected_put_inner_thoughts_in_kwargs
    assert updated.reasoning_effort == expected_reasoning_effort
    assert updated.max_reasoning_tokens == expected_max_reasoning_tokens
|
||||
|
||||
Reference in New Issue
Block a user