feat: fallback to all mode for summarizer if error (#6465)
This commit is contained in:
committed by
Caren Thomas
parent
7fa141273d
commit
bd97b23025
@@ -9,7 +9,7 @@ These tests verify the complete summarization flow:
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -606,28 +606,35 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor,
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# SummarizerConfig Mode Tests (with pytest.patch)
|
||||
# SummarizerConfig Mode Tests (with pytest.patch) - Using LettaAgentV3
|
||||
# ======================================================================================================================
|
||||
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from unittest.mock import patch
|
||||
|
||||
SUMMARIZATION_MODES = [
|
||||
SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
|
||||
]
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig, get_default_summarizer_config
|
||||
|
||||
# Test both summarizer modes: "all" summarizes entire history, "sliding_window" keeps recent messages
|
||||
SUMMARIZER_CONFIG_MODES: list[Literal["all", "sliding_window"]] = ["all", "sliding_window"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("mode", SUMMARIZATION_MODES, ids=[m.value for m in SUMMARIZATION_MODES])
|
||||
@pytest.mark.parametrize("mode", SUMMARIZER_CONFIG_MODES, ids=SUMMARIZER_CONFIG_MODES)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
|
||||
async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig, mode: SummarizationMode):
|
||||
async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig, mode: Literal["all", "sliding_window"]):
|
||||
"""
|
||||
Test summarization with different modes and LLM configurations.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
Test summarization with different SummarizerConfig modes using LettaAgentV3.
|
||||
|
||||
This test verifies that both summarization modes work correctly:
|
||||
- "all": Summarizes the entire conversation history into a single summary
|
||||
- "sliding_window": Keeps recent messages and summarizes older ones
|
||||
"""
|
||||
# Create a conversation with enough messages to trigger summarization
|
||||
messages = []
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.system,
|
||||
content=[TextContent(type="text", text="You are a helpful assistant.")],
|
||||
)
|
||||
]
|
||||
for i in range(10):
|
||||
messages.append(
|
||||
PydanticMessage(
|
||||
@@ -644,26 +651,73 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
|
||||
agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
|
||||
|
||||
with patch("letta.agents.letta_agent_v2.summarizer_settings") as mock_settings:
|
||||
mock_settings.mode = mode
|
||||
mock_settings.message_buffer_limit = 10
|
||||
mock_settings.message_buffer_min = 3
|
||||
mock_settings.partial_evict_summarizer_percentage = 0.30
|
||||
mock_settings.max_summarizer_retries = 3
|
||||
# Create new messages that would be added during this step
|
||||
new_letta_messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(type="text", text="This is a new user message during this step.")],
|
||||
agent_id=agent_state.id,
|
||||
)
|
||||
]
|
||||
# Persist the new messages
|
||||
new_letta_messages = await server.message_manager.create_many_messages_async(new_letta_messages, actor=actor)
|
||||
|
||||
agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor)
|
||||
assert agent_loop.summarizer.mode == mode
|
||||
# Create a custom SummarizerConfig with the desired mode
|
||||
def mock_get_default_summarizer_config(model_settings):
|
||||
config = get_default_summarizer_config(model_settings)
|
||||
# Override the mode
|
||||
return SummarizerConfig(
|
||||
model_settings=config.model_settings,
|
||||
prompt=config.prompt,
|
||||
prompt_acknowledgement=config.prompt_acknowledgement,
|
||||
clip_chars=config.clip_chars,
|
||||
mode=mode,
|
||||
sliding_window_percentage=config.sliding_window_percentage,
|
||||
)
|
||||
|
||||
with patch("letta.agents.letta_agent_v3.get_default_summarizer_config", mock_get_default_summarizer_config):
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=[],
|
||||
new_letta_messages=new_letta_messages,
|
||||
total_tokens=None,
|
||||
force=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) >= 1
|
||||
print(f"{mode.value} with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
|
||||
|
||||
# Verify that the result contains valid messages
|
||||
for msg in result:
|
||||
assert hasattr(msg, "role")
|
||||
assert hasattr(msg, "content")
|
||||
|
||||
print()
|
||||
print(f"RESULTS {mode} ======")
|
||||
for msg in result:
|
||||
print(f"MSG: {msg}")
|
||||
|
||||
print()
|
||||
|
||||
if mode == "all":
|
||||
# For "all" mode, result should be just the summary message
|
||||
assert len(result) == 2, f"Expected 1 message for 'all' mode, got {len(result)}"
|
||||
else:
|
||||
# For "sliding_window" mode, result should include recent messages + summary
|
||||
assert len(result) > 1, f"Expected >1 messages for 'sliding_window' mode, got {len(result)}"
|
||||
# validate new user message
|
||||
assert result[-1].role == MessageRole.user and result[-1].agent_id == agent_state.id, (
|
||||
f"Expected new user message with agent_id {agent_state.id}, got {result[-1]}"
|
||||
)
|
||||
assert "This is a new user message" in result[-1].content[0].text, (
|
||||
f"Expected 'This is a new user message' in the user message, got {result[-1]}"
|
||||
)
|
||||
|
||||
# validate system message
|
||||
assert result[0].role == MessageRole.system
|
||||
# validate summary message
|
||||
assert "prior messages" in result[1].content[0].text, f"Expected 'prior messages' in the summary message, got {result[1]}"
|
||||
print(f"Mode '{mode}' with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -1134,67 +1188,70 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
|
||||
# print(f"Total LLM invocations: {call_count}")
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# async def test_summarize_all_with_real_llm(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
# """
|
||||
# Test the summarize_all function with real LLM calls.
|
||||
#
|
||||
# This test verifies that the 'all' summarization mode works correctly,
|
||||
# summarizing the entire conversation into a single summary string.
|
||||
# """
|
||||
# from letta.schemas.model import ModelSettings
|
||||
# from letta.services.summarizer.summarizer_all import summarize_all
|
||||
# from letta.services.summarizer.summarizer_config import get_default_summarizer_config
|
||||
#
|
||||
# # Create a summarizer config with "all" mode
|
||||
# model_settings = ModelSettings()
|
||||
# summarizer_config = get_default_summarizer_config(model_settings)
|
||||
# summarizer_config.mode = "all"
|
||||
#
|
||||
# # Create test messages - a simple conversation
|
||||
# messages = [
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.system,
|
||||
# content=[TextContent(type="text", text="You are a helpful assistant.")],
|
||||
# )
|
||||
# ]
|
||||
#
|
||||
# # Add 10 user-assistant pairs
|
||||
# for i in range(10):
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.user,
|
||||
# content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
|
||||
# )
|
||||
# )
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.assistant,
|
||||
# content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"
|
||||
#
|
||||
# # Call summarize_all with real LLM
|
||||
# summary = await summarize_all(
|
||||
# actor=actor,
|
||||
# llm_config=llm_config,
|
||||
# summarizer_config=summarizer_config,
|
||||
# in_context_messages=messages,
|
||||
# new_messages=[],
|
||||
# )
|
||||
#
|
||||
# # Verify the summary was generated
|
||||
# assert summary is not None
|
||||
# assert len(summary) > 0
|
||||
#
|
||||
# print(f"Successfully summarized {len(messages)} messages using 'all' mode")
|
||||
# print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
|
||||
# print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
|
||||
#
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
async def test_summarize_all(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
"""
|
||||
Test the summarize_all function with real LLM calls.
|
||||
|
||||
This test verifies that the 'all' summarization mode works correctly,
|
||||
summarizing the entire conversation into a single summary string.
|
||||
"""
|
||||
from letta.schemas.model import ModelSettings
|
||||
from letta.services.summarizer.summarizer_all import summarize_all
|
||||
from letta.services.summarizer.summarizer_config import get_default_summarizer_config
|
||||
|
||||
# Create a summarizer config with "all" mode
|
||||
model_settings = ModelSettings()
|
||||
summarizer_config = get_default_summarizer_config(model_settings)
|
||||
summarizer_config.mode = "all"
|
||||
|
||||
# Create test messages - a simple conversation
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.system,
|
||||
content=[TextContent(type="text", text="You are a helpful assistant.")],
|
||||
)
|
||||
]
|
||||
|
||||
# Add 10 user-assistant pairs
|
||||
for i in range(10):
|
||||
messages.append(
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
|
||||
)
|
||||
)
|
||||
messages.append(
|
||||
PydanticMessage(
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"
|
||||
|
||||
# Call summarize_all with real LLM
|
||||
summary, new_in_context_messages = await summarize_all(
|
||||
actor=actor,
|
||||
llm_config=llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Verify the summary was generated
|
||||
assert len(new_in_context_messages) == 0
|
||||
assert summary is not None
|
||||
assert len(summary) > 0
|
||||
assert len(summary) <= 2000
|
||||
|
||||
print(f"Successfully summarized {len(messages)} messages using 'all' mode")
|
||||
print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
|
||||
print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
|
||||
|
||||
Reference in New Issue
Block a user