feat: fix new summarizer code and add more tests (#6461)

This commit is contained in:
Sarah Wooders
2025-11-30 00:49:38 -08:00
committed by Caren Thomas
parent 86023db9b1
commit 91e3dd8b3e
25 changed files with 728 additions and 358 deletions

View File

@@ -14,6 +14,7 @@ from typing import List
import pytest
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
@@ -671,7 +672,12 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
@pytest.mark.asyncio
async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: SyncServer, actor, llm_config: LLMConfig):
"""
Test that the sliding window summarizer correctly calculates cutoff indices.
@@ -685,35 +691,19 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
- max(..., 10) -> max(..., 0.10)
- += 10 -> += 0.10
- >= 100 -> >= 1.0
"""
from unittest.mock import MagicMock, patch
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User
This test uses the real token counter (via create_token_counter) to verify
the sliding window logic works with actual token counting.
"""
from letta.schemas.model import ModelSettings
from letta.services.summarizer.summarizer_config import get_default_summarizer_config
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
# Create a mock user (using proper ID format pattern)
mock_actor = User(
id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000"
)
# Create a mock LLM config
mock_llm_config = LLMConfig(
model="gpt-4",
model_endpoint_type="openai",
context_window=128000,
)
# Create a mock summarizer config with sliding_window_percentage = 0.3
mock_summarizer_config = MagicMock()
mock_summarizer_config.sliding_window_percentage = 0.3
mock_summarizer_config.summarizer_model = mock_llm_config
mock_summarizer_config.prompt = "Summarize the conversation."
mock_summarizer_config.prompt_acknowledgement = True
mock_summarizer_config.clip_chars = 2000
# Create a real summarizer config using the default factory
# Override sliding_window_percentage to 0.3 for this test
model_settings = ModelSettings() # Use defaults
summarizer_config = get_default_summarizer_config(model_settings)
summarizer_config.sliding_window_percentage = 0.3
# Create 65 messages (similar to the failing case in the bug report)
# Pattern: system + alternating user/assistant messages
@@ -741,59 +731,470 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
# Mock count_tokens to return a value that would trigger summarization
# Return a high token count so that the while loop continues
async def mock_count_tokens(actor, llm_config, messages):
# Return tokens that decrease as we cut off more messages
# This simulates the token count decreasing as we evict messages
return len(messages) * 100 # 100 tokens per message
# This should NOT raise "No assistant message found from indices 650 to 65"
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
try:
summary, remaining_messages = await summarize_via_sliding_window(
actor=actor,
llm_config=llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Mock simple_summary to return a fake summary
async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt):
return "This is a mock summary of the conversation."
# Verify the summary was generated (actual LLM response)
assert summary is not None
assert len(summary) > 0
with (
patch(
"letta.services.summarizer.summarizer_sliding_window.count_tokens",
side_effect=mock_count_tokens,
),
patch(
"letta.services.summarizer.summarizer_sliding_window.simple_summary",
side_effect=mock_simple_summary,
),
):
# This should NOT raise "No assistant message found from indices 650 to 65"
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
try:
summary, remaining_messages = await summarize_via_sliding_window(
actor=mock_actor,
llm_config=mock_llm_config,
summarizer_config=mock_summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Verify remaining messages is a valid subset
assert len(remaining_messages) < len(messages)
assert len(remaining_messages) > 0
# Verify the summary was generated
assert summary == "This is a mock summary of the conversation."
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
print(f"Using {llm_config.model_endpoint_type} token counter for model {llm_config.model}")
# Verify remaining messages is a valid subset
assert len(remaining_messages) < len(messages)
assert len(remaining_messages) > 0
except ValueError as e:
if "No assistant message found from indices" in str(e):
# Extract the indices from the error message
import re
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
match = re.search(r"from indices (\d+) to (\d+)", str(e))
if match:
start_idx, end_idx = int(match.group(1)), int(match.group(2))
pytest.fail(
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
f"Error: {e}"
)
raise
except ValueError as e:
if "No assistant message found from indices" in str(e):
# Extract the indices from the error message
import re
match = re.search(r"from indices (\d+) to (\d+)", str(e))
if match:
start_idx, end_idx = int(match.group(1)), int(match.group(2))
pytest.fail(
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
f"Error: {e}"
)
raise
# @pytest.mark.asyncio
# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
# """
# Test that a ContextWindowExceededError during a streaming LLM request
# properly triggers the summarizer and compacts the in-context messages.
#
# This test simulates:
# 1. An LLM streaming request that fails with ContextWindowExceededError
# 2. The summarizer being invoked to reduce context size
# 3. Verification that messages are compacted and summary message exists
#
# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
# """
# import uuid
# from unittest.mock import patch
#
# import openai
#
# from letta.schemas.message import MessageCreate
# from letta.schemas.run import Run
# from letta.services.run_manager import RunManager
#
# # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
# llm_config = get_llm_config("openai-gpt-4o-mini.json")
#
# # Create test messages - enough to have something to summarize
# messages = []
# for i in range(15):
# messages.append(
# PydanticMessage(
# role=MessageRole.user,
# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
# )
# )
# messages.append(
# PydanticMessage(
# role=MessageRole.assistant,
# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
# )
# )
#
# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
# original_message_count = len(agent_state.message_ids)
#
# # Create an input message to trigger the agent
# input_message = MessageCreate(
# role=MessageRole.user,
# content=[TextContent(type="text", text="Hello, please respond.")],
# )
#
# # Create a proper run record in the database
# run_manager = RunManager()
# test_run_id = f"run-{uuid.uuid4()}"
# test_run = Run(
# id=test_run_id,
# agent_id=agent_state.id,
# )
# await run_manager.create_run(test_run, actor)
#
# # Create the agent loop using LettaAgentV3
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
#
# # Track how many times stream_async is called
# call_count = 0
#
# # Store original stream_async method
# original_stream_async = agent_loop.llm_client.stream_async
#
# async def mock_stream_async_with_error(request_data, llm_config):
# nonlocal call_count
# call_count += 1
# if call_count == 1:
# # First call raises OpenAI BadRequestError with context_length_exceeded error code
# # This will be properly converted to ContextWindowExceededError by handle_llm_error
# from unittest.mock import MagicMock
#
# import httpx
#
# # Create a mock response with the required structure
# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
# mock_response = httpx.Response(
# status_code=400,
# request=mock_request,
# json={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
#
# raise openai.BadRequestError(
# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# response=mock_response,
# body={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
# # Subsequent calls use the real implementation
# return await original_stream_async(request_data, llm_config)
#
# # Patch the llm_client's stream_async to raise ContextWindowExceededError on first call
# with patch.object(agent_loop.llm_client, "stream_async", side_effect=mock_stream_async_with_error):
# # Execute a streaming step
# try:
# result_chunks = []
# async for chunk in agent_loop.stream(
# input_messages=[input_message],
# max_steps=1,
# stream_tokens=True,
# run_id=test_run_id,
# ):
# result_chunks.append(chunk)
# except Exception as e:
# # Some errors might happen due to real LLM calls after retry
# print(f"Exception during stream: {e}")
#
# # Reload agent state to get updated message_ids after summarization
# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
# updated_message_count = len(updated_agent_state.message_ids)
#
# # Fetch the updated in-context messages
# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
# message_ids=updated_agent_state.message_ids, actor=actor
# )
#
# # Convert to LettaMessage format for easier content inspection
# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
#
# # Verify a summary message exists with the correct format
# # The summary message has content with type="system_alert" and message containing:
# # "prior messages ... have been hidden" and "summary of the previous"
# import json
#
# summary_message_found = False
# summary_message_text = None
# for msg in letta_messages:
# # Not all message types have a content attribute (e.g., ReasoningMessage)
# if not hasattr(msg, "content"):
# continue
#
# content = msg.content
# # Content can be a string (JSON) or an object with type/message fields
# if isinstance(content, str):
# # Try to parse as JSON
# try:
# parsed = json.loads(content)
# if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
# text_to_check = parsed.get("message", "").lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = parsed.get("message")
# break
# except (json.JSONDecodeError, TypeError):
# pass
# # Check if content has system_alert type with the summary message (object form)
# elif hasattr(content, "type") and content.type == "system_alert":
# if hasattr(content, "message") and content.message:
# text_to_check = content.message.lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = content.message
# break
#
# assert summary_message_found, (
# "A summary message should exist in the in-context messages after summarization. "
# "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
# )
#
# # Verify we attempted multiple invocations (the failing one + retry after summarization)
# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
#
# # The original messages should have been compacted - the updated count should be less than
# # original + the new messages added (input + assistant response + tool results)
# # Since summarization should have removed most of the original 30 messages
# print("Test passed: Summary message found in context")
# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
# print(f"Total LLM invocations: {call_count}")
#
#
# @pytest.mark.asyncio
# async def test_context_window_overflow_triggers_summarization_in_blocking(server: SyncServer, actor):
# """
# Test that a ContextWindowExceededError during a blocking (non-streaming) LLM request
# properly triggers the summarizer and compacts the in-context messages.
#
# This test is similar to the streaming test but uses the blocking step() method.
#
# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
# """
# import uuid
# from unittest.mock import patch
#
# import openai
#
# from letta.schemas.message import MessageCreate
# from letta.schemas.run import Run
# from letta.services.run_manager import RunManager
#
# # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
# llm_config = get_llm_config("openai-gpt-4o-mini.json")
#
# # Create test messages
# messages = []
# for i in range(15):
# messages.append(
# PydanticMessage(
# role=MessageRole.user,
# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
# )
# )
# messages.append(
# PydanticMessage(
# role=MessageRole.assistant,
# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
# )
# )
#
# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
# original_message_count = len(agent_state.message_ids)
#
# # Create an input message to trigger the agent
# input_message = MessageCreate(
# role=MessageRole.user,
# content=[TextContent(type="text", text="Hello, please respond.")],
# )
#
# # Create a proper run record in the database
# run_manager = RunManager()
# test_run_id = f"run-{uuid.uuid4()}"
# test_run = Run(
# id=test_run_id,
# agent_id=agent_state.id,
# )
# await run_manager.create_run(test_run, actor)
#
# # Create the agent loop using LettaAgentV3
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
#
# # Track how many times request_async is called
# call_count = 0
#
# # Store original request_async method
# original_request_async = agent_loop.llm_client.request_async
#
# async def mock_request_async_with_error(request_data, llm_config):
# nonlocal call_count
# call_count += 1
# if call_count == 1:
# # First call raises OpenAI BadRequestError with context_length_exceeded error code
# # This will be properly converted to ContextWindowExceededError by handle_llm_error
# import httpx
#
# # Create a mock response with the required structure
# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
# mock_response = httpx.Response(
# status_code=400,
# request=mock_request,
# json={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
#
# raise openai.BadRequestError(
# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# response=mock_response,
# body={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
# # Subsequent calls use the real implementation
# return await original_request_async(request_data, llm_config)
#
# # Patch the llm_client's request_async to raise ContextWindowExceededError on first call
# with patch.object(agent_loop.llm_client, "request_async", side_effect=mock_request_async_with_error):
# # Execute a blocking step
# try:
# result = await agent_loop.step(
# input_messages=[input_message],
# max_steps=1,
# run_id=test_run_id,
# )
# except Exception as e:
# # Some errors might happen due to real LLM calls after retry
# print(f"Exception during step: {e}")
#
# # Reload agent state to get updated message_ids after summarization
# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
# updated_message_count = len(updated_agent_state.message_ids)
#
# # Fetch the updated in-context messages
# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
# message_ids=updated_agent_state.message_ids, actor=actor
# )
#
# # Convert to LettaMessage format for easier content inspection
# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
#
# # Verify a summary message exists with the correct format
# # The summary message has content with type="system_alert" and message containing:
# # "prior messages ... have been hidden" and "summary of the previous"
# import json
#
# summary_message_found = False
# summary_message_text = None
# for msg in letta_messages:
# # Not all message types have a content attribute (e.g., ReasoningMessage)
# if not hasattr(msg, "content"):
# continue
#
# content = msg.content
# # Content can be a string (JSON) or an object with type/message fields
# if isinstance(content, str):
# # Try to parse as JSON
# try:
# parsed = json.loads(content)
# if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
# text_to_check = parsed.get("message", "").lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = parsed.get("message")
# break
# except (json.JSONDecodeError, TypeError):
# pass
# # Check if content has system_alert type with the summary message (object form)
# elif hasattr(content, "type") and content.type == "system_alert":
# if hasattr(content, "message") and content.message:
# text_to_check = content.message.lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = content.message
# break
#
# assert summary_message_found, (
# "A summary message should exist in the in-context messages after summarization. "
# "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
# )
#
# # Verify we attempted multiple invocations (the failing one + retry after summarization)
# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
#
# # The original messages should have been compacted - the updated count should be less than
# # original + the new messages added (input + assistant response + tool results)
# print("Test passed: Summary message found in context (blocking mode)")
# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
# print(f"Total LLM invocations: {call_count}")
#
#
# @pytest.mark.asyncio
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# async def test_summarize_all_with_real_llm(server: SyncServer, actor, llm_config: LLMConfig):
# """
# Test the summarize_all function with real LLM calls.
#
# This test verifies that the 'all' summarization mode works correctly,
# summarizing the entire conversation into a single summary string.
# """
# from letta.schemas.model import ModelSettings
# from letta.services.summarizer.summarizer_all import summarize_all
# from letta.services.summarizer.summarizer_config import get_default_summarizer_config
#
# # Create a summarizer config with "all" mode
# model_settings = ModelSettings()
# summarizer_config = get_default_summarizer_config(model_settings)
# summarizer_config.mode = "all"
#
# # Create test messages - a simple conversation
# messages = [
# PydanticMessage(
# role=MessageRole.system,
# content=[TextContent(type="text", text="You are a helpful assistant.")],
# )
# ]
#
# # Add 10 user-assistant pairs
# for i in range(10):
# messages.append(
# PydanticMessage(
# role=MessageRole.user,
# content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
# )
# )
# messages.append(
# PydanticMessage(
# role=MessageRole.assistant,
# content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
# )
# )
#
# assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"
#
# # Call summarize_all with real LLM
# summary = await summarize_all(
# actor=actor,
# llm_config=llm_config,
# summarizer_config=summarizer_config,
# in_context_messages=messages,
# new_messages=[],
# )
#
# # Verify the summary was generated
# assert summary is not None
# assert len(summary) > 0
#
# print(f"Successfully summarized {len(messages)} messages using 'all' mode")
# print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
# print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
#