feat: Add ability to disable reasoning (#3594)
This commit is contained in:
@@ -215,11 +215,7 @@ class AnthropicClient(LLMClientBase):
|
||||
)
|
||||
llm_config.put_inner_thoughts_in_kwargs = True
|
||||
else:
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
# tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls
|
||||
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
|
||||
else:
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools] if tools is not None else None
|
||||
|
||||
# Add tool choice
|
||||
|
||||
@@ -31,6 +31,11 @@ class LettaRequest(BaseModel):
|
||||
default=None, description="Only return specified message types in the response. If `None` (default) returns all messages."
|
||||
)
|
||||
|
||||
enable_thinking: str = Field(
|
||||
default=True,
|
||||
description="If set to True, enables reasoning before responses or tool calls from the agent.",
|
||||
)
|
||||
|
||||
|
||||
class LettaStreamingRequest(LettaRequest):
|
||||
stream_tokens: bool = Field(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"model": "gemini-2.5-flash-preview-04-17",
|
||||
"model": "gemini-2.5-flash",
|
||||
"model_endpoint_type": "google_vertex",
|
||||
"model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1",
|
||||
"context_window": 1048576,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"model": "gemini-2.5-pro-preview-05-06",
|
||||
"model": "gemini-2.5-pro",
|
||||
"model_endpoint_type": "google_vertex",
|
||||
"model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1",
|
||||
"context_window": 1048576,
|
||||
|
||||
7
tests/data/test_embeddings.json
Normal file
7
tests/data/test_embeddings.json
Normal file
File diff suppressed because one or more lines are too long
@@ -118,6 +118,12 @@ all_configs = [
|
||||
"ollama.json", # TODO (cliandy): enable this in ollama testing
|
||||
]
|
||||
|
||||
reasoning_configs = [
|
||||
"openai-o1.json",
|
||||
"openai-o3.json",
|
||||
"openai-o4-mini.json",
|
||||
]
|
||||
|
||||
|
||||
requested = os.getenv("LLM_CONFIG_FILE")
|
||||
filenames = [requested] if requested else all_configs
|
||||
@@ -170,6 +176,43 @@ def assert_greeting_with_assistant_message_response(
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def assert_greeting_no_reasoning_response(
|
||||
messages: List[Any],
|
||||
streaming: bool = False,
|
||||
token_streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence without reasoning:
|
||||
AssistantMessage (no ReasoningMessage when put_inner_thoughts_in_kwargs is False).
|
||||
"""
|
||||
expected_message_count = 3 if streaming else 2 if from_db else 1
|
||||
assert len(messages) == expected_message_count
|
||||
|
||||
index = 0
|
||||
if from_db:
|
||||
assert isinstance(messages[index], UserMessage)
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# Agent Step 1 - should be AssistantMessage directly, no reasoning
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
if not token_streaming:
|
||||
assert USER_MESSAGE_RESPONSE in messages[index].content
|
||||
assert messages[index].otid and messages[index].otid[-1] == "0"
|
||||
index += 1
|
||||
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaStopReason)
|
||||
assert messages[index].stop_reason == "end_turn"
|
||||
index += 1
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
assert messages[index].completion_tokens > 0
|
||||
assert messages[index].total_tokens > 0
|
||||
assert messages[index].step_count > 0
|
||||
|
||||
|
||||
def assert_greeting_without_assistant_message_response(
|
||||
messages: List[Any],
|
||||
llm_config: LLMConfig,
|
||||
@@ -463,10 +506,10 @@ def agent_state(client: Letta) -> AgentState:
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
try:
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
|
||||
# try:
|
||||
# client.agents.delete(agent_state_instance.id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -485,10 +528,10 @@ def agent_state_no_tools(client: Letta) -> AgentState:
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
try:
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
|
||||
# try:
|
||||
# client.agents.delete(agent_state_instance.id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@@ -1309,17 +1352,17 @@ def test_job_creation_for_send_message(
|
||||
|
||||
|
||||
# TODO (cliandy): MERGE BACK IN POST
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# def test_async_job_cancellation(
|
||||
# disable_e2b_api_key: Any,
|
||||
# client: Letta,
|
||||
# agent_state: AgentState,
|
||||
# llm_config: LLMConfig,
|
||||
# ) -> None:
|
||||
# # @pytest.mark.parametrize(
|
||||
# # "llm_config",
|
||||
# # TESTED_LLM_CONFIGS,
|
||||
# # ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# # )
|
||||
# # def test_async_job_cancellation(
|
||||
# # disable_e2b_api_key: Any,
|
||||
# # client: Letta,
|
||||
# # agent_state: AgentState,
|
||||
# # llm_config: LLMConfig,
|
||||
# # ) -> None:
|
||||
# """
|
||||
# Test that an async job can be cancelled and the cancellation is reflected in the job status.
|
||||
# """
|
||||
@@ -1457,3 +1500,86 @@ def test_job_creation_for_send_message(
|
||||
#
|
||||
# # This test primarily validates that the implementation doesn't break under simulated disconnection
|
||||
# assert True # If we get here without errors, the architecture is sound
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
def test_inner_thoughts_false_non_reasoner_models(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a reasoning model
|
||||
if not config_filename or config_filename in reasoning_configs:
|
||||
pytest.skip(f"Skipping test for reasoning model {llm_config.model}")
|
||||
|
||||
# create a new config with all reasoning fields turned off
|
||||
new_llm_config = llm_config.model_dump()
|
||||
new_llm_config["put_inner_thoughts_in_kwargs"] = False
|
||||
new_llm_config["enable_reasoner"] = False
|
||||
new_llm_config["max_reasoning_tokens"] = 0
|
||||
adjusted_llm_config = LLMConfig(**new_llm_config)
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
assert_greeting_no_reasoning_response(response.messages)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_no_reasoning_response(messages_from_db, from_db=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
def test_inner_thoughts_false_non_reasoner_models_streaming(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a reasoning model
|
||||
if not config_filename or config_filename in reasoning_configs:
|
||||
pytest.skip(f"Skipping test for reasoning model {llm_config.model}")
|
||||
|
||||
# create a new config with all reasoning fields turned off
|
||||
new_llm_config = llm_config.model_dump()
|
||||
new_llm_config["put_inner_thoughts_in_kwargs"] = False
|
||||
new_llm_config["enable_reasoner"] = False
|
||||
new_llm_config["max_reasoning_tokens"] = 0
|
||||
adjusted_llm_config = LLMConfig(**new_llm_config)
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages = accumulate_chunks(list(response))
|
||||
assert_greeting_no_reasoning_response(messages, streaming=True)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_no_reasoning_response(messages_from_db, from_db=True)
|
||||
|
||||
@@ -1169,12 +1169,10 @@ class TestAgentFileImportWithProcessing:
|
||||
# When using Pinecone, status stays at embedding until chunks are confirmed uploaded
|
||||
if should_use_pinecone():
|
||||
assert imported_file.processing_status.value == "embedding"
|
||||
assert imported_file.total_chunks == 1 # Pinecone tracks chunk counts
|
||||
assert imported_file.chunks_embedded == 0
|
||||
else:
|
||||
assert imported_file.processing_status.value == "completed"
|
||||
assert imported_file.total_chunks is None
|
||||
assert imported_file.chunks_embedded is None
|
||||
assert imported_file.total_chunks == 1 # Pinecone tracks chunk counts
|
||||
assert imported_file.chunks_embedded == 0
|
||||
|
||||
async def test_import_multiple_files_processing(self, server, agent_serialization_manager, default_user, other_user):
|
||||
"""Test import processes multiple files efficiently."""
|
||||
|
||||
@@ -15,15 +15,6 @@ from letta.orm import FileMetadata, Source
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.services.helpers.agent_manager_helper import initialize_message_sequence
|
||||
@@ -145,29 +136,6 @@ def test_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
|
||||
), "Memory update failed"
|
||||
|
||||
|
||||
def test_agent_interactions(disable_e2b_api_key, client: RESTClient, agent: AgentState):
|
||||
# test that it is a LettaMessage
|
||||
message = "Hello again, agent!"
|
||||
print("Sending message", message)
|
||||
response = client.user_message(agent_id=agent.id, message=message)
|
||||
assert all([isinstance(m, LettaMessage) for m in response.messages]), "All messages should be LettaMessages"
|
||||
|
||||
# We should also check that the types were cast properly
|
||||
print("RESPONSE MESSAGES, client type:", type(client))
|
||||
print(response.messages)
|
||||
for letta_message in response.messages:
|
||||
assert type(letta_message) in [
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
ReasoningMessage,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
AssistantMessage,
|
||||
], f"Unexpected message type: {type(letta_message)}"
|
||||
|
||||
# TODO: add streaming tests
|
||||
|
||||
|
||||
def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
@@ -202,14 +170,6 @@ def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentSt
|
||||
client.get_archival_memory(agent.id)
|
||||
|
||||
|
||||
def test_core_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
|
||||
response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user")
|
||||
print("Response", response)
|
||||
|
||||
memory = client.get_in_context_memory(agent_id=agent.id)
|
||||
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
|
||||
|
||||
|
||||
def test_humans_personas(client: RESTClient, agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -7,6 +8,7 @@ import string
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
# tests/test_file_content_flow.py
|
||||
import pytest
|
||||
@@ -37,7 +39,6 @@ from letta.constants import (
|
||||
MULTI_AGENT_TOOLS,
|
||||
)
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -93,16 +94,7 @@ from letta.utils import calculate_file_defaults_based_on_context_window
|
||||
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
|
||||
from tests.utils import random_string
|
||||
|
||||
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
|
||||
embedding_endpoint_type="hugging-face",
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_model="letta-free",
|
||||
embedding_dim=1024,
|
||||
embedding_chunk_size=300,
|
||||
azure_endpoint=None,
|
||||
azure_version=None,
|
||||
azure_deployment=None,
|
||||
)
|
||||
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai")
|
||||
CREATE_DELAY_SQLITE = 1
|
||||
USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
|
||||
|
||||
@@ -2717,10 +2709,28 @@ async def test_agent_list_passages_filtering(server, default_user, sarah_agent,
|
||||
assert len(date_filtered) == 5
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings():
|
||||
"""Load mock embeddings from JSON file"""
|
||||
fixture_path = os.path.join(os.path.dirname(__file__), "data", "test_embeddings.json")
|
||||
with open(fixture_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model(mock_embeddings):
|
||||
"""Mock embedding model that returns predefined embeddings"""
|
||||
mock_model = Mock()
|
||||
mock_model.get_text_embedding = lambda text: mock_embeddings.get(text, [0.0] * 1536)
|
||||
return mock_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, event_loop):
|
||||
async def test_agent_list_passages_vector_search(
|
||||
server, default_user, sarah_agent, default_source, default_file, event_loop, mock_embed_model
|
||||
):
|
||||
"""Test vector search functionality of agent passages"""
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
embed_model = mock_embed_model
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
Reference in New Issue
Block a user