diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 65f8fce2..ecb2c663 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -185,31 +185,50 @@ class LLMConfig(BaseModel): + (f" [ip={self.model_endpoint}]" if self.model_endpoint else "") ) + @classmethod + def is_openai_reasoning_model(cls, config: "LLMConfig") -> bool: + return config.model_endpoint_type == "openai" and ( + config.model.startswith("o1") or config.model.startswith("o3") or config.model.startswith("o4") + ) + + @classmethod + def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool: + return config.model_endpoint_type == "anthropic" and ( + config.model.startswith("claude-opus-4") + or config.model.startswith("claude-sonnet-4") + or config.model.startswith("claude-3-7-sonnet") + ) + + @classmethod + def is_google_vertex_reasoning_model(cls, config: "LLMConfig") -> bool: + return config.model_endpoint_type == "google_vertex" and ( + config.model.startswith("gemini-2.5-flash") or config.model.startswith("gemini-2.5-pro") + ) + @classmethod def apply_reasoning_setting_to_config(cls, config: "LLMConfig", reasoning: bool): if not reasoning: + if cls.is_openai_reasoning_model(config) or config.model.startswith("gemini-2.5-pro"): + raise ValueError("Reasoning cannot be disabled for OpenAI o-series (o1/o3/o4) or Gemini 2.5 Pro models") config.put_inner_thoughts_in_kwargs = False config.enable_reasoner = False else: config.enable_reasoner = True - if ( - config.model_endpoint_type == "anthropic" - and ("claude-opus-4" in config.model or "claude-sonnet-4" in config.model or "claude-3-7-sonnet" in config.model) - ) or ( - config.model_endpoint_type == "google_vertex" and ("gemini-2.5-flash" in config.model or "gemini-2.0-pro" in config.model) - ): + if cls.is_anthropic_reasoning_model(config): config.put_inner_thoughts_in_kwargs = False if config.max_reasoning_tokens == 0: config.max_reasoning_tokens = 1024 - elif config.model_endpoint_type == "openai" and ( - 
config.model.startswith("o1") or config.model.startswith("o3") or config.model.startswith("o4") - ): + elif cls.is_google_vertex_reasoning_model(config): + # Handle as non-reasoner until we support summary + config.put_inner_thoughts_in_kwargs = True + if config.max_reasoning_tokens == 0: + config.max_reasoning_tokens = 1024 + elif cls.is_openai_reasoning_model(config): config.put_inner_thoughts_in_kwargs = False if config.reasoning_effort is None: config.reasoning_effort = "medium" else: config.put_inner_thoughts_in_kwargs = True - config.enable_reasoner = False return config diff --git a/tests/test_providers.py b/tests/test_providers.py index 0b7c50e7..cb7a54bd 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,5 +1,8 @@ +from typing import Literal, Optional + import pytest +from letta.schemas.llm_config import LLMConfig from letta.schemas.providers import ( AnthropicProvider, AzureProvider, @@ -314,3 +317,51 @@ async def test_provider_llm_models_consistency(): assert model.handle.startswith(f"{provider.name}/") assert model.provider_name == provider.name assert model.context_window > 0 + + +@pytest.mark.parametrize( + "handle, expected_enable_reasoner, expected_put_inner_thoughts_in_kwargs, expected_max_reasoning_tokens, expected_reasoning_effort, expected_exception", + [ + ("openai/gpt-4o-mini", True, True, 0, None, None), + ("openai/gpt-4o-mini", False, False, 0, None, None), + ("openai/o3-mini", True, False, 0, "medium", None), + ("openai/o3-mini", False, False, 0, None, ValueError), + ("anthropic/claude-3.5-sonnet", True, True, 0, None, None), + ("anthropic/claude-3.5-sonnet", False, False, 0, None, None), + ("anthropic/claude-3-7-sonnet", True, False, 1024, None, None), + ("anthropic/claude-3-7-sonnet", False, False, 0, None, None), + ("anthropic/claude-sonnet-4", True, False, 1024, None, None), + ("anthropic/claude-sonnet-4", False, False, 0, None, None), + ("google_vertex/gemini-2.0-flash", True, True, 0, None, None), + 
("google_vertex/gemini-2.0-flash", False, False, 0, None, None), + ("google_vertex/gemini-2.5-flash", True, True, 1024, None, None), + ("google_vertex/gemini-2.5-flash", False, False, 0, None, None), + ("google_vertex/gemini-2.5-pro", True, True, 1024, None, None), + ("google_vertex/gemini-2.5-pro", False, False, 0, None, ValueError), + ], +) +def test_reasoning_toggle_by_provider( + handle: str, + expected_enable_reasoner: bool, + expected_put_inner_thoughts_in_kwargs: bool, + expected_max_reasoning_tokens: int, + expected_reasoning_effort: Optional[Literal["minimal", "low", "medium", "high"]], + expected_exception: Optional[type[Exception]], +): + model_endpoint_type, model = handle.split("/") + config = LLMConfig( + model_endpoint_type=model_endpoint_type, + model=model, + handle=handle, + context_window=1024, + ) + if expected_exception: + with pytest.raises(expected_exception): + LLMConfig.apply_reasoning_setting_to_config(config, reasoning=expected_enable_reasoner) + else: + new_config = LLMConfig.apply_reasoning_setting_to_config(config, reasoning=expected_enable_reasoner) + + assert new_config.enable_reasoner == expected_enable_reasoner + assert new_config.put_inner_thoughts_in_kwargs == expected_put_inner_thoughts_in_kwargs + assert new_config.reasoning_effort == expected_reasoning_effort + assert new_config.max_reasoning_tokens == expected_max_reasoning_tokens