feat: Scrub inner thoughts from history on toggle (#3607)

This commit is contained in:
Matthew Zhou
2025-07-28 21:43:36 -07:00
committed by GitHub
parent 04511d1ffc
commit 272b36c63f
3 changed files with 221 additions and 1 deletion

View File

@@ -22,6 +22,7 @@ from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms
from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
@@ -1169,6 +1170,9 @@ class LettaAgent(BaseAgent):
tool_rules_solver=tool_rules_solver,
)
# scrub inner thoughts from messages if reasoning is completely disabled
in_context_messages = scrub_inner_thoughts_from_messages(in_context_messages, agent_state.llm_config)
tools = [
t
for t in agent_state.tools

View File

@@ -0,0 +1,48 @@
from typing import List
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
def is_reasoning_completely_disabled(llm_config: LLMConfig) -> bool:
    """
    Report whether every reasoning-related setting on the config is switched off.

    Reasoning counts as fully disabled only when all three knobs are off:
    - put_inner_thoughts_in_kwargs is False
    - enable_reasoner is False
    - max_reasoning_tokens is 0

    Args:
        llm_config: The LLM configuration to inspect

    Returns:
        True if reasoning is completely disabled, False otherwise
    """
    # identity check against False (not mere falsiness) mirrors the strictness of the toggle
    return all(
        (
            llm_config.put_inner_thoughts_in_kwargs is False,
            llm_config.enable_reasoner is False,
            llm_config.max_reasoning_tokens == 0,
        )
    )
def scrub_inner_thoughts_from_messages(messages: List[Message], llm_config: LLMConfig) -> List[Message]:
    """
    Strip inner thoughts (reasoning text) from assistant messages when reasoning is fully disabled.

    Presents a clean history to the LLM so it looks as though reasoning was never
    enabled. Messages are mutated in place; the same list object is returned.

    Args:
        messages: Messages that may need scrubbing
        llm_config: Configuration used to decide whether scrubbing applies

    Returns:
        The message list with inner thoughts removed if reasoning is disabled, otherwise unchanged
    """
    # nothing to do while any reasoning knob is still on
    if not is_reasoning_completely_disabled(llm_config):
        return messages

    for msg in messages:
        # only assistant turns that carry both text content and tool calls are scrubbed
        if msg.role != MessageRole.assistant or not msg.content or not msg.tool_calls:
            continue
        # drop the TextContent items (the inner thoughts); keep any other content kinds
        msg.content = [part for part in msg.content if not isinstance(part, TextContent)]
    return messages

View File

@@ -12,7 +12,7 @@ import httpx
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta, MessageCreate, Run
from letta_client import AsyncLetta, Letta, LettaRequest, MessageCreate, Run
from letta_client.core.api_error import ApiError
from letta_client.types import (
AssistantMessage,
@@ -30,6 +30,7 @@ from letta_client.types import (
UserMessage,
)
from letta.helpers.reasoning_helper import is_reasoning_completely_disabled
from letta.llm_api.openai_client import is_openai_reasoning_model
from letta.log import get_logger
from letta.schemas.agent import AgentState
@@ -71,6 +72,13 @@ USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
# Single-turn greeting used to seed a conversation (e.g. before toggling reasoning off).
USER_MESSAGE_GREETING: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content="Hi!",  # plain literal: the original f-string had no placeholders (ruff F541)
        otid=USER_MESSAGE_OTID,
    )
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
MessageCreate(
role="user",
@@ -330,6 +338,99 @@ def assert_tool_call_response(
assert messages[index].step_count > 0
def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
    """
    Validate that OpenAI-format assistant messages with tool calls carry no inner-thoughts text.

    Args:
        messages: List of message dictionaries in OpenAI format

    Raises:
        AssertionError: If no tool-calling assistant message exists, or one still has text content.
    """
    with_tool_calls = [m for m in messages if m.get("role") == "assistant" and m.get("tool_calls")]

    # the payload must contain at least one tool-calling assistant turn to inspect
    assert len(with_tool_calls) > 0, "Expected at least one OpenAI assistant message with tool calls"

    # a scrubbed assistant turn may only carry a null content field
    for m in with_tool_calls:
        if "content" in m:
            assert m["content"] is None
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
    """
    Validate that Anthropic/Claude-format assistant messages with tool_use carry no text items.

    Args:
        messages: List of message dictionaries in Anthropic format

    Raises:
        AssertionError: If no tool_use assistant message exists, or one still carries text content.
    """

    def _has_tool_use(msg: Dict[str, Any]) -> bool:
        # assistant turn whose list-shaped content includes at least one tool_use part
        parts = msg.get("content")
        if msg.get("role") != "assistant" or not isinstance(parts, list):
            return False
        return any(part.get("type") == "tool_use" for part in parts)

    tool_use_messages = [m for m in messages if _has_tool_use(m)]

    # the payload must contain at least one tool_use assistant turn to inspect
    assert len(tool_use_messages) > 0, "Expected at least one Claude assistant message with tool_use"

    for msg in tool_use_messages:
        parts = msg["content"]
        text_parts = [p for p in parts if p.get("type") == "text"]
        tool_use_parts = [p for p in parts if p.get("type") == "tool_use"]

        # strict: with reasoning disabled there must be zero text items (no <thinking> residue)
        assert len(text_parts) == 0, (
            f"Found {len(text_parts)} text content item(s) in Claude assistant message with tool_use. "
            f"When reasoning is disabled, there should be NO text items. "
            f"Text items found: {[p.get('text', '') for p in text_parts]}"
        )
        # and the message must consist solely of tool_use items
        assert len(tool_use_parts) > 0, "Assistant message should have at least one tool_use item"
        assert len(parts) == len(tool_use_parts), (
            f"Assistant message should ONLY contain tool_use items when reasoning is disabled. "
            f"Found {len(parts)} total items but only {len(tool_use_parts)} are tool_use items."
        )
def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None:
    """
    Validate that Google/Gemini-format model parts with functionCall carry no 'thinking' arg.

    Args:
        contents: List of content dictionaries in Google format (uses 'contents' instead of 'messages')

    Raises:
        AssertionError: If no functionCall part exists, or one still has a 'thinking' argument.
    """
    # collect every functionCall part emitted by the model role
    function_call_parts = [
        part
        for entry in contents
        if entry.get("role") == "model" and isinstance(entry.get("parts"), list)
        for part in entry["parts"]
        if "functionCall" in part
    ]

    # the payload must contain at least one functionCall part to inspect
    assert len(function_call_parts) > 0, "Expected at least one Google model message with functionCall"

    for part in function_call_parts:
        args = part["functionCall"].get("args", {})
        # with reasoning disabled, no inner thoughts should be smuggled via a 'thinking' arg
        assert (
            "thinking" not in args
        ), f"Found 'thinking' field in Google model functionCall args (inner thoughts not scrubbed): {args.get('thinking')}"
def assert_image_input_response(
messages: List[Any],
llm_config: LLMConfig,
@@ -1583,3 +1684,70 @@ def test_inner_thoughts_false_non_reasoner_models_streaming(
assert_greeting_no_reasoning_response(messages, streaming=True)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_greeting_no_reasoning_response(messages_from_db, from_db=True)
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_inner_thoughts_toggle_interleaved(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """Verify that after reasoning is toggled off mid-conversation, the next raw
    provider payload has prior inner thoughts scrubbed from assistant tool-call turns."""
    # find which config file produced this llm_config
    config_filename = next(
        (name for name in filenames if get_llm_config(name).model_dump() == llm_config.model_dump()),
        None,
    )
    # reasoning models keep their thoughts; this test only applies to non-reasoners
    if not config_filename or config_filename in reasoning_configs:
        pytest.skip(f"Skipping test for reasoning model {llm_config.model}")
    # payload-shape validators only exist for these provider families
    if llm_config.model_endpoint_type not in ("openai", "anthropic", "google_ai", "google_vertex"):
        pytest.skip(f"Skipping `test_inner_thoughts_toggle_interleaved` for model endpoint type {llm_config.model_endpoint_type}")

    assert not is_reasoning_completely_disabled(llm_config), "Reasoning should be enabled"
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # seed the history while reasoning is still on, so inner thoughts get recorded
    client.agents.messages.create(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_GREETING,
    )

    # flip every reasoning knob off and push the adjusted config to the agent
    disabled_fields = {
        "put_inner_thoughts_in_kwargs": False,
        "enable_reasoner": False,
        "max_reasoning_tokens": 0,
    }
    adjusted_llm_config = LLMConfig(**{**llm_config.model_dump(), **disabled_fields})
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config)

    # preview the raw provider payload for the next request without sending it
    response = client.agents.messages.preview_raw_payload(
        agent_id=agent_state.id,
        request=LettaRequest(messages=USER_MESSAGE_FORCE_REPLY),
    )

    # sanity-check the helper agrees the new config fully disables reasoning
    assert is_reasoning_completely_disabled(adjusted_llm_config), "Reasoning should be completely disabled"

    # provider families shape their payloads differently, so validate per endpoint type
    endpoint_type = llm_config.model_endpoint_type
    if endpoint_type == "openai":
        validate_openai_format_scrubbing(response["messages"])
    elif endpoint_type == "anthropic":
        validate_anthropic_format_scrubbing(response["messages"])
    elif endpoint_type in ["google_ai", "google_vertex"]:
        # Google payloads use 'contents' rather than 'messages'
        validate_google_format_scrubbing(response.get("contents", response.get("messages", [])))