fix: patch summarizer and add tests (#6457)

This commit is contained in:
Sarah Wooders
2025-11-29 20:01:40 -08:00
committed by Caren Thomas
parent e67c98eedb
commit c0b422c4c6
2 changed files with 138 additions and 4 deletions

View File

@@ -73,8 +73,8 @@ async def summarize_via_sliding_window(
# Starts at N% (eg 70%), and increments up until 100%
message_count_cutoff_percent = max(
1 - summarizer_config.sliding_window_percentage, 10
) # Some arbitrary minimum value to avoid negatives from badly configured summarizer percentage
1 - summarizer_config.sliding_window_percentage, 0.10
) # Some arbitrary minimum value (10%) to avoid negatives from badly configured summarizer percentage
found_cutoff = False
# Count tokens with system prompt, and message past cutoff point
@@ -98,8 +98,8 @@ async def summarize_via_sliding_window(
if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
found_cutoff = True
else:
message_count_cutoff_percent += 10
if message_count_cutoff_percent >= 100:
message_count_cutoff_percent += 0.10
if message_count_cutoff_percent >= 1.0:
message_cutoff_index = total_message_count
break

View File

@@ -663,3 +663,137 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
assert isinstance(result, list)
assert len(result) >= 1
print(f"{mode.value} with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
# ======================================================================================================================
# Sliding Window Summarizer Unit Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
"""
Test that the sliding window summarizer correctly calculates cutoff indices.
This test verifies the fix for a bug where the cutoff percentage was treated as
a whole number (10) instead of a decimal (0.10), causing:
message_cutoff_index = round(10 * 65) = 650
when there were only 65 messages, resulting in an empty range loop and the error:
"No assistant message found from indices 650 to 65"
The fix changed:
- max(..., 10) -> max(..., 0.10)
- += 10 -> += 0.10
- >= 100 -> >= 1.0
"""
from unittest.mock import MagicMock, patch
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
# Create a mock user (using proper ID format pattern)
mock_actor = User(
id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000"
)
# Create a mock LLM config
mock_llm_config = LLMConfig(
model="gpt-4",
model_endpoint_type="openai",
context_window=128000,
)
# Create a mock summarizer config with sliding_window_percentage = 0.3
mock_summarizer_config = MagicMock()
mock_summarizer_config.sliding_window_percentage = 0.3
mock_summarizer_config.summarizer_model = mock_llm_config
mock_summarizer_config.prompt = "Summarize the conversation."
mock_summarizer_config.prompt_acknowledgement = True
mock_summarizer_config.clip_chars = 2000
# Create 65 messages (similar to the failing case in the bug report)
# Pattern: system + alternating user/assistant messages
messages = [
PydanticMessage(
role=MessageRole.system,
content=[TextContent(type="text", text="You are a helpful assistant.")],
)
]
# Add 64 more messages (32 user-assistant pairs)
for i in range(32):
messages.append(
PydanticMessage(
role=MessageRole.user,
content=[TextContent(type="text", text=f"User message {i}")],
)
)
messages.append(
PydanticMessage(
role=MessageRole.assistant,
content=[TextContent(type="text", text=f"Assistant response {i}")],
)
)
assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
# Mock count_tokens to return a value that would trigger summarization
# Return a high token count so that the while loop continues
async def mock_count_tokens(actor, llm_config, messages):
# Return tokens that decrease as we cut off more messages
# This simulates the token count decreasing as we evict messages
return len(messages) * 100 # 100 tokens per message
# Mock simple_summary to return a fake summary
async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt):
return "This is a mock summary of the conversation."
with (
patch(
"letta.services.summarizer.summarizer_sliding_window.count_tokens",
side_effect=mock_count_tokens,
),
patch(
"letta.services.summarizer.summarizer_sliding_window.simple_summary",
side_effect=mock_simple_summary,
),
):
# This should NOT raise "No assistant message found from indices 650 to 65"
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
try:
summary, remaining_messages = await summarize_via_sliding_window(
actor=mock_actor,
llm_config=mock_llm_config,
summarizer_config=mock_summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Verify the summary was generated
assert summary == "This is a mock summary of the conversation."
# Verify remaining messages is a valid subset
assert len(remaining_messages) < len(messages)
assert len(remaining_messages) > 0
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
except ValueError as e:
if "No assistant message found from indices" in str(e):
# Extract the indices from the error message
import re
match = re.search(r"from indices (\d+) to (\d+)", str(e))
if match:
start_idx, end_idx = int(match.group(1)), int(match.group(2))
pytest.fail(
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
f"Error: {e}"
)
raise