fix: patch summarizer and add tests (#6457)
This commit is contained in:
committed by
Caren Thomas
parent
e67c98eedb
commit
c0b422c4c6
@@ -73,8 +73,8 @@ async def summarize_via_sliding_window(
|
||||
|
||||
# Starts at N% (eg 70%), and increments up until 100%
|
||||
message_count_cutoff_percent = max(
|
||||
1 - summarizer_config.sliding_window_percentage, 10
|
||||
) # Some arbitrary minimum value to avoid negatives from badly configured summarizer percentage
|
||||
1 - summarizer_config.sliding_window_percentage, 0.10
|
||||
) # Some arbitrary minimum value (10%) to avoid negatives from badly configured summarizer percentage
|
||||
found_cutoff = False
|
||||
|
||||
# Count tokens with system prompt, and message past cutoff point
|
||||
@@ -98,8 +98,8 @@ async def summarize_via_sliding_window(
|
||||
if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
|
||||
found_cutoff = True
|
||||
else:
|
||||
message_count_cutoff_percent += 10
|
||||
if message_count_cutoff_percent >= 100:
|
||||
message_count_cutoff_percent += 0.10
|
||||
if message_count_cutoff_percent >= 1.0:
|
||||
message_cutoff_index = total_message_count
|
||||
break
|
||||
|
||||
|
||||
@@ -663,3 +663,137 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
assert isinstance(result, list)
|
||||
assert len(result) >= 1
|
||||
print(f"{mode.value} with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Sliding Window Summarizer Unit Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
|
||||
"""
|
||||
Test that the sliding window summarizer correctly calculates cutoff indices.
|
||||
|
||||
This test verifies the fix for a bug where the cutoff percentage was treated as
|
||||
a whole number (10) instead of a decimal (0.10), causing:
|
||||
message_cutoff_index = round(10 * 65) = 650
|
||||
when there were only 65 messages, resulting in an empty range loop and the error:
|
||||
"No assistant message found from indices 650 to 65"
|
||||
|
||||
The fix changed:
|
||||
- max(..., 10) -> max(..., 0.10)
|
||||
- += 10 -> += 0.10
|
||||
- >= 100 -> >= 1.0
|
||||
"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User
|
||||
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
|
||||
|
||||
# Create a mock user (using proper ID format pattern)
|
||||
mock_actor = User(
|
||||
id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
|
||||
# Create a mock LLM config
|
||||
mock_llm_config = LLMConfig(
|
||||
model="gpt-4",
|
||||
model_endpoint_type="openai",
|
||||
context_window=128000,
|
||||
)
|
||||
|
||||
# Create a mock summarizer config with sliding_window_percentage = 0.3
|
||||
mock_summarizer_config = MagicMock()
|
||||
mock_summarizer_config.sliding_window_percentage = 0.3
|
||||
mock_summarizer_config.summarizer_model = mock_llm_config
|
||||
mock_summarizer_config.prompt = "Summarize the conversation."
|
||||
mock_summarizer_config.prompt_acknowledgement = True
|
||||
mock_summarizer_config.clip_chars = 2000
|
||||
|
||||
# Create 65 messages (similar to the failing case in the bug report)
|
||||
# Pattern: system + alternating user/assistant messages
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.system,
|
||||
content=[TextContent(type="text", text="You are a helpful assistant.")],
|
||||
)
|
||||
]
|
||||
|
||||
# Add 64 more messages (32 user-assistant pairs)
|
||||
for i in range(32):
|
||||
messages.append(
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(type="text", text=f"User message {i}")],
|
||||
)
|
||||
)
|
||||
messages.append(
|
||||
PydanticMessage(
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(type="text", text=f"Assistant response {i}")],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
|
||||
|
||||
# Mock count_tokens to return a value that would trigger summarization
|
||||
# Return a high token count so that the while loop continues
|
||||
async def mock_count_tokens(actor, llm_config, messages):
|
||||
# Return tokens that decrease as we cut off more messages
|
||||
# This simulates the token count decreasing as we evict messages
|
||||
return len(messages) * 100 # 100 tokens per message
|
||||
|
||||
# Mock simple_summary to return a fake summary
|
||||
async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt):
|
||||
return "This is a mock summary of the conversation."
|
||||
|
||||
with (
|
||||
patch(
|
||||
"letta.services.summarizer.summarizer_sliding_window.count_tokens",
|
||||
side_effect=mock_count_tokens,
|
||||
),
|
||||
patch(
|
||||
"letta.services.summarizer.summarizer_sliding_window.simple_summary",
|
||||
side_effect=mock_simple_summary,
|
||||
),
|
||||
):
|
||||
# This should NOT raise "No assistant message found from indices 650 to 65"
|
||||
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
|
||||
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
|
||||
try:
|
||||
summary, remaining_messages = await summarize_via_sliding_window(
|
||||
actor=mock_actor,
|
||||
llm_config=mock_llm_config,
|
||||
summarizer_config=mock_summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Verify the summary was generated
|
||||
assert summary == "This is a mock summary of the conversation."
|
||||
|
||||
# Verify remaining messages is a valid subset
|
||||
assert len(remaining_messages) < len(messages)
|
||||
assert len(remaining_messages) > 0
|
||||
|
||||
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
|
||||
|
||||
except ValueError as e:
|
||||
if "No assistant message found from indices" in str(e):
|
||||
# Extract the indices from the error message
|
||||
import re
|
||||
|
||||
match = re.search(r"from indices (\d+) to (\d+)", str(e))
|
||||
if match:
|
||||
start_idx, end_idx = int(match.group(1)), int(match.group(2))
|
||||
pytest.fail(
|
||||
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
|
||||
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user