feat: Add ability to disable reasoning (#3594)

This commit is contained in:
Matthew Zhou
2025-07-28 15:30:10 -07:00
committed by GitHub
parent 2cd985ef8a
commit 84ea52172a
9 changed files with 185 additions and 83 deletions

View File

@@ -215,11 +215,7 @@ class AnthropicClient(LLMClientBase):
)
llm_config.put_inner_thoughts_in_kwargs = True
else:
if llm_config.put_inner_thoughts_in_kwargs:
# tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
else:
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
tools_for_request = [OpenAITool(function=f) for f in tools] if tools is not None else None
# Add tool choice

View File

@@ -31,6 +31,11 @@ class LettaRequest(BaseModel):
default=None, description="Only return specified message types in the response. If `None` (default) returns all messages."
)
enable_thinking: bool = Field(
default=True,
description="If set to True, enables reasoning before responses or tool calls from the agent.",
)
class LettaStreamingRequest(LettaRequest):
stream_tokens: bool = Field(

View File

@@ -1,5 +1,5 @@
{
"model": "gemini-2.5-flash-preview-04-17",
"model": "gemini-2.5-flash",
"model_endpoint_type": "google_vertex",
"model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1",
"context_window": 1048576,

View File

@@ -1,5 +1,5 @@
{
"model": "gemini-2.5-pro-preview-05-06",
"model": "gemini-2.5-pro",
"model_endpoint_type": "google_vertex",
"model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1",
"context_window": 1048576,

File diff suppressed because one or more lines are too long

View File

@@ -118,6 +118,12 @@ all_configs = [
"ollama.json", # TODO (cliandy): enable this in ollama testing
]
reasoning_configs = [
"openai-o1.json",
"openai-o3.json",
"openai-o4-mini.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
@@ -170,6 +176,43 @@ def assert_greeting_with_assistant_message_response(
assert messages[index].step_count > 0
def assert_greeting_no_reasoning_response(
    messages: List[Any],
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Verify a greeting exchange produced with reasoning disabled.

    Expected shape: an AssistantMessage with no preceding ReasoningMessage
    (put_inner_thoughts_in_kwargs is False), optionally preceded by the
    persisted UserMessage (from_db) and followed by the stop reason and
    usage statistics (streaming).
    """
    # 3 = assistant + stop reason + usage (streaming);
    # 2 = user + assistant (db read-back); 1 = assistant only.
    assert len(messages) == (3 if streaming else 2 if from_db else 1)

    cursor = 0
    if from_db:
        user_msg = messages[cursor]
        assert isinstance(user_msg, UserMessage)
        assert user_msg.otid == USER_MESSAGE_OTID
        cursor += 1

    # Agent step 1 - the assistant reply must come first, with no reasoning message.
    assistant_msg = messages[cursor]
    assert isinstance(assistant_msg, AssistantMessage)
    if not token_streaming:
        assert USER_MESSAGE_RESPONSE in assistant_msg.content
    assert assistant_msg.otid and assistant_msg.otid[-1] == "0"
    cursor += 1

    if streaming:
        stop = messages[cursor]
        assert isinstance(stop, LettaStopReason)
        assert stop.stop_reason == "end_turn"
        cursor += 1
        usage = messages[cursor]
        assert isinstance(usage, LettaUsageStatistics)
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0
        assert usage.step_count > 0
def assert_greeting_without_assistant_message_response(
messages: List[Any],
llm_config: LLMConfig,
@@ -463,10 +506,10 @@ def agent_state(client: Letta) -> AgentState:
)
yield agent_state_instance
try:
client.agents.delete(agent_state_instance.id)
except Exception as e:
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
# try:
# client.agents.delete(agent_state_instance.id)
# except Exception as e:
# logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
@pytest.fixture(scope="function")
@@ -485,10 +528,10 @@ def agent_state_no_tools(client: Letta) -> AgentState:
)
yield agent_state_instance
try:
client.agents.delete(agent_state_instance.id)
except Exception as e:
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
# try:
# client.agents.delete(agent_state_instance.id)
# except Exception as e:
# logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
# ------------------------------
@@ -1309,17 +1352,17 @@ def test_job_creation_for_send_message(
# TODO (cliandy): MERGE BACK IN POST
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# def test_async_job_cancellation(
# disable_e2b_api_key: Any,
# client: Letta,
# agent_state: AgentState,
# llm_config: LLMConfig,
# ) -> None:
# # @pytest.mark.parametrize(
# # "llm_config",
# # TESTED_LLM_CONFIGS,
# # ids=[c.model for c in TESTED_LLM_CONFIGS],
# # )
# # def test_async_job_cancellation(
# # disable_e2b_api_key: Any,
# # client: Letta,
# # agent_state: AgentState,
# # llm_config: LLMConfig,
# # ) -> None:
# """
# Test that an async job can be cancelled and the cancellation is reflected in the job status.
# """
@@ -1457,3 +1500,86 @@ def test_job_creation_for_send_message(
#
# # This test primarily validates that the implementation doesn't break under simulated disconnection
# assert True # If we get here without errors, the architecture is sound
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_inner_thoughts_false_non_reasoner_models(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """Non-reasoner models must reply without a ReasoningMessage once every
    reasoning-related config flag is switched off."""
    # Map the parametrized config back to its source filename so we can tell
    # whether it belongs to the dedicated reasoning-model set.
    target_dump = llm_config.model_dump()
    config_filename = next(
        (name for name in filenames if get_llm_config(name).model_dump() == target_dump),
        None,
    )
    if not config_filename or config_filename in reasoning_configs:
        pytest.skip(f"Skipping test for reasoning model {llm_config.model}")

    # Rebuild the config with every reasoning knob disabled.
    overrides = dict(
        target_dump,
        put_inner_thoughts_in_kwargs=False,
        enable_reasoner=False,
        max_reasoning_tokens=0,
    )
    adjusted_llm_config = LLMConfig(**overrides)

    # Remember the newest message so the DB read-back below only sees new ones.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config)

    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
    )
    assert_greeting_no_reasoning_response(response.messages)

    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_no_reasoning_response(messages_from_db, from_db=True)
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_inner_thoughts_false_non_reasoner_models_streaming(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """Streaming variant: with all reasoning flags off, the accumulated
    stream must contain no ReasoningMessage."""
    # Map the parametrized config back to its source filename so we can tell
    # whether it belongs to the dedicated reasoning-model set.
    target_dump = llm_config.model_dump()
    config_filename = next(
        (name for name in filenames if get_llm_config(name).model_dump() == target_dump),
        None,
    )
    if not config_filename or config_filename in reasoning_configs:
        pytest.skip(f"Skipping test for reasoning model {llm_config.model}")

    # Rebuild the config with every reasoning knob disabled.
    overrides = dict(
        target_dump,
        put_inner_thoughts_in_kwargs=False,
        enable_reasoner=False,
        max_reasoning_tokens=0,
    )
    adjusted_llm_config = LLMConfig(**overrides)

    # Remember the newest message so the DB read-back below only sees new ones.
    last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config)

    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
    )
    messages = accumulate_chunks(list(response))
    assert_greeting_no_reasoning_response(messages, streaming=True)

    messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_no_reasoning_response(messages_from_db, from_db=True)

View File

@@ -1169,12 +1169,10 @@ class TestAgentFileImportWithProcessing:
# When using Pinecone, status stays at embedding until chunks are confirmed uploaded
if should_use_pinecone():
assert imported_file.processing_status.value == "embedding"
assert imported_file.total_chunks == 1 # Pinecone tracks chunk counts
assert imported_file.chunks_embedded == 0
else:
assert imported_file.processing_status.value == "completed"
assert imported_file.total_chunks is None
assert imported_file.chunks_embedded is None
assert imported_file.total_chunks == 1 # Pinecone tracks chunk counts
assert imported_file.chunks_embedded == 0
async def test_import_multiple_files_processing(self, server, agent_serialization_manager, default_user, other_user):
"""Test import processes multiple files efficiently."""

View File

@@ -15,15 +15,6 @@ from letta.orm import FileMetadata, Source
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import (
AssistantMessage,
LettaMessage,
ReasoningMessage,
SystemMessage,
ToolCallMessage,
ToolReturnMessage,
UserMessage,
)
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.services.helpers.agent_manager_helper import initialize_message_sequence
@@ -145,29 +136,6 @@ def test_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
), "Memory update failed"
def test_agent_interactions(disable_e2b_api_key, client: RESTClient, agent: AgentState):
# test that it is a LettaMessage
message = "Hello again, agent!"
print("Sending message", message)
response = client.user_message(agent_id=agent.id, message=message)
assert all([isinstance(m, LettaMessage) for m in response.messages]), "All messages should be LettaMessages"
# We should also check that the types were cast properly
print("RESPONSE MESSAGES, client type:", type(client))
print(response.messages)
for letta_message in response.messages:
assert type(letta_message) in [
SystemMessage,
UserMessage,
ReasoningMessage,
ToolCallMessage,
ToolReturnMessage,
AssistantMessage,
], f"Unexpected message type: {type(letta_message)}"
# TODO: add streaming tests
def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
# _reset_config()
@@ -202,14 +170,6 @@ def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentSt
client.get_archival_memory(agent.id)
def test_core_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user")
print("Response", response)
memory = client.get_in_context_memory(agent_id=agent.id)
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
def test_humans_personas(client: RESTClient, agent: AgentState):
# _reset_config()

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import logging
import os
import random
@@ -7,6 +8,7 @@ import string
import time
from datetime import datetime, timedelta, timezone
from typing import List
from unittest.mock import Mock
# tests/test_file_content_flow.py
import pytest
@@ -37,7 +39,6 @@ from letta.constants import (
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
@@ -93,16 +94,7 @@ from letta.utils import calculate_file_defaults_based_on_context_window
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
from tests.utils import random_string
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_model="letta-free",
embedding_dim=1024,
embedding_chunk_size=300,
azure_endpoint=None,
azure_version=None,
azure_deployment=None,
)
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai")
CREATE_DELAY_SQLITE = 1
USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
@@ -2717,10 +2709,28 @@ async def test_agent_list_passages_filtering(server, default_user, sarah_agent,
assert len(date_filtered) == 5
@pytest.fixture
def mock_embeddings():
    """Load the canned text -> embedding mapping from the JSON fixture file."""
    # The fixture lives in the `data` directory next to this test module.
    data_file = os.path.join(os.path.dirname(__file__), "data", "test_embeddings.json")
    with open(data_file) as fh:
        return json.load(fh)
@pytest.fixture
def mock_embed_model(mock_embeddings):
    """Build a stand-in embedding model backed by the canned embeddings.

    Unknown texts fall back to a fresh 1536-dim zero vector so lookups
    never fail.
    """

    def _lookup(text):
        # Serve the precomputed embedding when available, else a zero vector.
        return mock_embeddings.get(text, [0.0] * 1536)

    model = Mock()
    model.get_text_embedding = _lookup
    return model
@pytest.mark.asyncio
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, event_loop):
async def test_agent_list_passages_vector_search(
server, default_user, sarah_agent, default_source, default_file, event_loop, mock_embed_model
):
"""Test vector search functionality of agent passages"""
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
embed_model = mock_embed_model
# Create passages with known embeddings
passages = []