feat: handle all cases for reasoning toggle (#3857)

This commit is contained in:
cthomas
2025-08-11 13:45:50 -07:00
committed by GitHub
parent 4dcbf0b8f2
commit 32a4ecae9d
2 changed files with 80 additions and 10 deletions

View File

@@ -185,31 +185,50 @@ class LLMConfig(BaseModel):
+ (f" [ip={self.model_endpoint}]" if self.model_endpoint else "")
)
@classmethod
def is_openai_reasoning_model(cls, config: "LLMConfig") -> bool:
    """Return whether ``config`` targets an OpenAI o-series reasoning model.

    Args:
        config: Configuration whose ``model_endpoint_type`` and ``model``
            fields are inspected.

    Returns:
        True when the endpoint type is ``"openai"`` and the model name starts
        with ``o1``, ``o3`` or ``o4``; False otherwise.
    """
    # str.startswith accepts a tuple of prefixes, so one call covers them all.
    return config.model_endpoint_type == "openai" and config.model.startswith(("o1", "o3", "o4"))
@classmethod
def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool:
    """Return whether ``config`` targets an Anthropic reasoning model.

    Args:
        config: Configuration whose ``model_endpoint_type`` and ``model``
            fields are inspected.

    Returns:
        True when the endpoint type is ``"anthropic"`` and the model name
        starts with ``claude-opus-4``, ``claude-sonnet-4`` or
        ``claude-3-7-sonnet``; False otherwise.
    """
    reasoning_prefixes = ("claude-opus-4", "claude-sonnet-4", "claude-3-7-sonnet")
    # A tuple argument lets a single startswith call test every prefix.
    return config.model_endpoint_type == "anthropic" and config.model.startswith(reasoning_prefixes)
@classmethod
def is_google_vertex_reasoning_model(cls, config: "LLMConfig") -> bool:
    """Return whether ``config`` targets a Google Vertex reasoning model.

    Args:
        config: Configuration whose ``model_endpoint_type`` and ``model``
            fields are inspected.

    Returns:
        True when the endpoint type is ``"google_vertex"`` and the model name
        starts with ``gemini-2.5-flash`` or ``gemini-2.5-pro``; False
        otherwise.
    """
    # A tuple argument lets a single startswith call test both prefixes.
    return config.model_endpoint_type == "google_vertex" and config.model.startswith(
        ("gemini-2.5-flash", "gemini-2.5-pro")
    )
@classmethod
def apply_reasoning_setting_to_config(cls, config: "LLMConfig", reasoning: bool) -> "LLMConfig":
    """Toggle reasoning on ``config`` according to what its model supports.

    The config is mutated in place and also returned.

    Args:
        config: Configuration to update.
        reasoning: True to turn reasoning on, False to turn it off.

    Returns:
        The same ``config`` object, after its ``enable_reasoner``,
        ``put_inner_thoughts_in_kwargs``, ``max_reasoning_tokens`` and
        ``reasoning_effort`` fields have been adjusted.

    Raises:
        ValueError: If ``reasoning`` is False for an OpenAI reasoning model
            or for a model whose name starts with ``gemini-2.5-pro``.
    """
    if not reasoning:
        # Each rejected model family gets its own message so the error names
        # the model that actually triggered it.
        if cls.is_openai_reasoning_model(config):
            raise ValueError("Reasoning cannot be disabled for OpenAI o1/o3 models")
        if config.model.startswith("gemini-2.5-pro"):
            raise ValueError("Reasoning cannot be disabled for Gemini 2.5 Pro models")
        config.put_inner_thoughts_in_kwargs = False
        config.enable_reasoner = False
        return config

    config.enable_reasoner = True
    if cls.is_anthropic_reasoning_model(config):
        config.put_inner_thoughts_in_kwargs = False
        # Only fill in a token budget when none has been set.
        if config.max_reasoning_tokens == 0:
            config.max_reasoning_tokens = 1024
    elif cls.is_google_vertex_reasoning_model(config):
        # Handle as non-reasoner until we support summary
        config.put_inner_thoughts_in_kwargs = True
        if config.max_reasoning_tokens == 0:
            config.max_reasoning_tokens = 1024
    elif cls.is_openai_reasoning_model(config):
        config.put_inner_thoughts_in_kwargs = False
        # Only fill in an effort level when none has been set.
        if config.reasoning_effort is None:
            config.reasoning_effort = "medium"
    else:
        # Models without native reasoning keep enable_reasoner=True and carry
        # their inner thoughts in the tool-call kwargs instead.
        config.put_inner_thoughts_in_kwargs = True
    return config

View File

@@ -1,5 +1,8 @@
from typing import Literal, Optional
import pytest
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import (
AnthropicProvider,
AzureProvider,
@@ -314,3 +317,51 @@ async def test_provider_llm_models_consistency():
assert model.handle.startswith(f"{provider.name}/")
assert model.provider_name == provider.name
assert model.context_window > 0
# Each row: (handle, reasoning flag / expected enable_reasoner,
# expected put_inner_thoughts_in_kwargs, expected max_reasoning_tokens,
# expected reasoning_effort, exception class expected or None).
@pytest.mark.parametrize(
    "handle, expected_enable_reasoner, expected_put_inner_thoughts_in_kwargs, expected_max_reasoning_tokens, expected_reasoning_effort, expected_exception",
    [
        ("openai/gpt-4o-mini", True, True, 0, None, None),
        ("openai/gpt-4o-mini", False, False, 0, None, None),
        ("openai/o3-mini", True, False, 0, "medium", None),
        ("openai/o3-mini", False, False, 0, None, ValueError),
        ("anthropic/claude-3.5-sonnet", True, True, 0, None, None),
        ("anthropic/claude-3.5-sonnet", False, False, 0, None, None),
        ("anthropic/claude-3-7-sonnet", True, False, 1024, None, None),
        ("anthropic/claude-3-7-sonnet", False, False, 0, None, None),
        ("anthropic/claude-sonnet-4", True, False, 1024, None, None),
        ("anthropic/claude-sonnet-4", False, False, 0, None, None),
        ("google_vertex/gemini-2.0-flash", True, True, 0, None, None),
        ("google_vertex/gemini-2.0-flash", False, False, 0, None, None),
        ("google_vertex/gemini-2.5-flash", True, True, 1024, None, None),
        ("google_vertex/gemini-2.5-flash", False, False, 0, None, None),
        ("google_vertex/gemini-2.5-pro", True, True, 1024, None, None),
        ("google_vertex/gemini-2.5-pro", False, False, 0, None, ValueError),
    ],
)
def test_reasoning_toggle_by_provider(
    handle: str,
    expected_enable_reasoner: bool,
    expected_put_inner_thoughts_in_kwargs: bool,
    expected_max_reasoning_tokens: int,
    expected_reasoning_effort: Optional[Literal["minimal", "low", "medium", "high"]],
    expected_exception: Optional[type[Exception]],
):
    """Check ``LLMConfig.apply_reasoning_setting_to_config`` for each handle.

    ``expected_enable_reasoner`` doubles as the ``reasoning`` argument passed
    to the method under test. When ``expected_exception`` is set, only the
    raise is checked and the remaining expectations in that row are unused.
    """
    # maxsplit=1 keeps unpacking safe if a model name ever contains a slash.
    model_endpoint_type, model = handle.split("/", 1)
    config = LLMConfig(
        model_endpoint_type=model_endpoint_type,
        model=model,
        handle=handle,
        context_window=1024,
    )

    if expected_exception:
        with pytest.raises(expected_exception):
            LLMConfig.apply_reasoning_setting_to_config(config, reasoning=expected_enable_reasoner)
        return

    new_config = LLMConfig.apply_reasoning_setting_to_config(config, reasoning=expected_enable_reasoner)
    assert new_config.enable_reasoner == expected_enable_reasoner
    assert new_config.put_inner_thoughts_in_kwargs == expected_put_inner_thoughts_in_kwargs
    assert new_config.reasoning_effort == expected_reasoning_effort
    assert new_config.max_reasoning_tokens == expected_max_reasoning_tokens