feat: fix new summarizer code and add more tests (#6461)

committed by Caren Thomas
parent 86023db9b1
commit 91e3dd8b3e
@@ -14,6 +14,7 @@ from typing import List
 import pytest
 
 from letta.agents.letta_agent_v2 import LettaAgentV2
+from letta.agents.letta_agent_v3 import LettaAgentV3
 from letta.config import LettaConfig
 from letta.schemas.agent import CreateAgent
 from letta.schemas.embedding_config import EmbeddingConfig
@@ -671,7 +672,12 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig):
 
 
 @pytest.mark.asyncio
-async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
+@pytest.mark.parametrize(
+    "llm_config",
+    TESTED_LLM_CONFIGS,
+    ids=[c.model for c in TESTED_LLM_CONFIGS],
+)
+async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: SyncServer, actor, llm_config: LLMConfig):
     """
     Test that the sliding window summarizer correctly calculates cutoff indices.
 
@@ -685,35 +691,19 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
-    - max(..., 10) -> max(..., 0.10)
-    - += 10 -> += 0.10
-    - >= 100 -> >= 1.0
-    """
-    from unittest.mock import MagicMock, patch
-
-    from letta.schemas.enums import MessageRole
-    from letta.schemas.letta_message_content import TextContent
-    from letta.schemas.llm_config import LLMConfig
-    from letta.schemas.message import Message as PydanticMessage
-    from letta.schemas.user import User
+    This test uses the real token counter (via create_token_counter) to verify
+    the sliding window logic works with actual token counting.
+    """
+    from letta.schemas.model import ModelSettings
+    from letta.services.summarizer.summarizer_config import get_default_summarizer_config
+    from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
 
-    # Create a mock user (using proper ID format pattern)
-    mock_actor = User(
-        id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000"
-    )
-
-    # Create a mock LLM config
-    mock_llm_config = LLMConfig(
-        model="gpt-4",
-        model_endpoint_type="openai",
-        context_window=128000,
-    )
-
-    # Create a mock summarizer config with sliding_window_percentage = 0.3
-    mock_summarizer_config = MagicMock()
-    mock_summarizer_config.sliding_window_percentage = 0.3
-    mock_summarizer_config.summarizer_model = mock_llm_config
-    mock_summarizer_config.prompt = "Summarize the conversation."
-    mock_summarizer_config.prompt_acknowledgement = True
-    mock_summarizer_config.clip_chars = 2000
+    # Create a real summarizer config using the default factory
+    # Override sliding_window_percentage to 0.3 for this test
+    model_settings = ModelSettings()  # Use defaults
+    summarizer_config = get_default_summarizer_config(model_settings)
+    summarizer_config.sliding_window_percentage = 0.3
 
     # Create 65 messages (similar to the failing case in the bug report)
     # Pattern: system + alternating user/assistant messages
@@ -741,59 +731,470 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
 
     assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
 
-    # Mock count_tokens to return a value that would trigger summarization
-    # Return a high token count so that the while loop continues
-    async def mock_count_tokens(actor, llm_config, messages):
-        # Return tokens that decrease as we cut off more messages
-        # This simulates the token count decreasing as we evict messages
-        return len(messages) * 100  # 100 tokens per message
-
-    # Mock simple_summary to return a fake summary
-    async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt):
-        return "This is a mock summary of the conversation."
-
-    with (
-        patch(
-            "letta.services.summarizer.summarizer_sliding_window.count_tokens",
-            side_effect=mock_count_tokens,
-        ),
-        patch(
-            "letta.services.summarizer.summarizer_sliding_window.simple_summary",
-            side_effect=mock_simple_summary,
-        ),
-    ):
-        # This should NOT raise "No assistant message found from indices 650 to 65"
-        # With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
-        # So message_cutoff_index = round(0.7 * 65) = 46, which is valid
-        try:
-            summary, remaining_messages = await summarize_via_sliding_window(
-                actor=mock_actor,
-                llm_config=mock_llm_config,
-                summarizer_config=mock_summarizer_config,
-                in_context_messages=messages,
-                new_messages=[],
-            )
-
-            # Verify the summary was generated
-            assert summary == "This is a mock summary of the conversation."
-
-            # Verify remaining messages is a valid subset
-            assert len(remaining_messages) < len(messages)
-            assert len(remaining_messages) > 0
-
-            print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
-
-        except ValueError as e:
-            if "No assistant message found from indices" in str(e):
-                # Extract the indices from the error message
-                import re
-
-                match = re.search(r"from indices (\d+) to (\d+)", str(e))
-                if match:
-                    start_idx, end_idx = int(match.group(1)), int(match.group(2))
-                    pytest.fail(
-                        f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
-                        f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
-                        f"Error: {e}"
-                    )
-            raise
+    # This should NOT raise "No assistant message found from indices 650 to 65"
+    # With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
+    # So message_cutoff_index = round(0.7 * 65) = 46, which is valid
+    try:
+        summary, remaining_messages = await summarize_via_sliding_window(
+            actor=actor,
+            llm_config=llm_config,
+            summarizer_config=summarizer_config,
+            in_context_messages=messages,
+            new_messages=[],
+        )
+
+        # Verify the summary was generated (actual LLM response)
+        assert summary is not None
+        assert len(summary) > 0
+
+        # Verify remaining messages is a valid subset
+        assert len(remaining_messages) < len(messages)
+        assert len(remaining_messages) > 0
+
+        print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
+        print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
+        print(f"Using {llm_config.model_endpoint_type} token counter for model {llm_config.model}")
+
+    except ValueError as e:
+        if "No assistant message found from indices" in str(e):
+            # Extract the indices from the error message
+            import re
+
+            match = re.search(r"from indices (\d+) to (\d+)", str(e))
+            if match:
+                start_idx, end_idx = int(match.group(1)), int(match.group(2))
+                pytest.fail(
+                    f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
+                    f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
+                    f"Error: {e}"
+                )
+        raise
+
+
+# @pytest.mark.asyncio
+# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
+#     """
+#     Test that a ContextWindowExceededError during a streaming LLM request
+#     properly triggers the summarizer and compacts the in-context messages.
+#
+#     This test simulates:
+#     1. An LLM streaming request that fails with ContextWindowExceededError
+#     2. The summarizer being invoked to reduce context size
+#     3. Verification that messages are compacted and summary message exists
+#
+#     Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
+#     """
+#     import uuid
+#     from unittest.mock import patch
+#
+#     import openai
+#
+#     from letta.schemas.message import MessageCreate
+#     from letta.schemas.run import Run
+#     from letta.services.run_manager import RunManager
+#
+#     # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
+#     llm_config = get_llm_config("openai-gpt-4o-mini.json")
+#
+#     # Create test messages - enough to have something to summarize
+#     messages = []
+#     for i in range(15):
+#         messages.append(
+#             PydanticMessage(
+#                 role=MessageRole.user,
+#                 content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
+#             )
+#         )
+#         messages.append(
+#             PydanticMessage(
+#                 role=MessageRole.assistant,
+#                 content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
+#             )
+#         )
+#
+#     agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
+#     original_message_count = len(agent_state.message_ids)
+#
+#     # Create an input message to trigger the agent
+#     input_message = MessageCreate(
+#         role=MessageRole.user,
+#         content=[TextContent(type="text", text="Hello, please respond.")],
+#     )
+#
+#     # Create a proper run record in the database
+#     run_manager = RunManager()
+#     test_run_id = f"run-{uuid.uuid4()}"
+#     test_run = Run(
+#         id=test_run_id,
+#         agent_id=agent_state.id,
+#     )
+#     await run_manager.create_run(test_run, actor)
+#
+#     # Create the agent loop using LettaAgentV3
+#     agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
+#
+#     # Track how many times stream_async is called
+#     call_count = 0
+#
+#     # Store original stream_async method
+#     original_stream_async = agent_loop.llm_client.stream_async
+#
+#     async def mock_stream_async_with_error(request_data, llm_config):
+#         nonlocal call_count
+#         call_count += 1
+#         if call_count == 1:
+#             # First call raises OpenAI BadRequestError with context_length_exceeded error code
+#             # This will be properly converted to ContextWindowExceededError by handle_llm_error
+#             from unittest.mock import MagicMock
+#
+#             import httpx
+#
+#             # Create a mock response with the required structure
+#             mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
+#             mock_response = httpx.Response(
+#                 status_code=400,
+#                 request=mock_request,
+#                 json={
+#                     "error": {
+#                         "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
+#                         "type": "invalid_request_error",
+#                         "code": "context_length_exceeded",
+#                     }
+#                 },
+#             )
+#
+#             raise openai.BadRequestError(
+#                 message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
+#                 response=mock_response,
+#                 body={
+#                     "error": {
+#                         "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
+#                         "type": "invalid_request_error",
+#                         "code": "context_length_exceeded",
+#                     }
+#                 },
+#             )
+#         # Subsequent calls use the real implementation
+#         return await original_stream_async(request_data, llm_config)
+#
+#     # Patch the llm_client's stream_async to raise ContextWindowExceededError on first call
+#     with patch.object(agent_loop.llm_client, "stream_async", side_effect=mock_stream_async_with_error):
+#         # Execute a streaming step
+#         try:
+#             result_chunks = []
+#             async for chunk in agent_loop.stream(
+#                 input_messages=[input_message],
+#                 max_steps=1,
+#                 stream_tokens=True,
+#                 run_id=test_run_id,
+#             ):
+#                 result_chunks.append(chunk)
+#         except Exception as e:
+#             # Some errors might happen due to real LLM calls after retry
+#             print(f"Exception during stream: {e}")
+#
+#     # Reload agent state to get updated message_ids after summarization
+#     updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
+#     updated_message_count = len(updated_agent_state.message_ids)
+#
+#     # Fetch the updated in-context messages
+#     updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
+#         message_ids=updated_agent_state.message_ids, actor=actor
+#     )
+#
+#     # Convert to LettaMessage format for easier content inspection
+#     letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
+#
+#     # Verify a summary message exists with the correct format
+#     # The summary message has content with type="system_alert" and message containing:
+#     # "prior messages ... have been hidden" and "summary of the previous"
+#     import json
+#
+#     summary_message_found = False
+#     summary_message_text = None
+#     for msg in letta_messages:
+#         # Not all message types have a content attribute (e.g., ReasoningMessage)
+#         if not hasattr(msg, "content"):
+#             continue
+#
+#         content = msg.content
+#         # Content can be a string (JSON) or an object with type/message fields
+#         if isinstance(content, str):
+#             # Try to parse as JSON
+#             try:
+#                 parsed = json.loads(content)
+#                 if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
+#                     text_to_check = parsed.get("message", "").lower()
+#                     if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
+#                         summary_message_found = True
+#                         summary_message_text = parsed.get("message")
+#                         break
+#             except (json.JSONDecodeError, TypeError):
+#                 pass
+#         # Check if content has system_alert type with the summary message (object form)
+#         elif hasattr(content, "type") and content.type == "system_alert":
+#             if hasattr(content, "message") and content.message:
+#                 text_to_check = content.message.lower()
+#                 if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
+#                     summary_message_found = True
+#                     summary_message_text = content.message
+#                     break
+#
+#     assert summary_message_found, (
+#         "A summary message should exist in the in-context messages after summarization. "
+#         "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
+#     )
+#
+#     # Verify we attempted multiple invocations (the failing one + retry after summarization)
+#     assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
+#
+#     # The original messages should have been compacted - the updated count should be less than
+#     # original + the new messages added (input + assistant response + tool results)
+#     # Since summarization should have removed most of the original 30 messages
+#     print("Test passed: Summary message found in context")
+#     print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
+#     print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
+#     print(f"Total LLM invocations: {call_count}")
+#
+#
+# @pytest.mark.asyncio
+# async def test_context_window_overflow_triggers_summarization_in_blocking(server: SyncServer, actor):
+#     """
+#     Test that a ContextWindowExceededError during a blocking (non-streaming) LLM request
+#     properly triggers the summarizer and compacts the in-context messages.
+#
+#     This test is similar to the streaming test but uses the blocking step() method.
+#
+#     Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
+#     """
+#     import uuid
+#     from unittest.mock import patch
+#
+#     import openai
+#
+#     from letta.schemas.message import MessageCreate
+#     from letta.schemas.run import Run
+#     from letta.services.run_manager import RunManager
+#
+#     # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
+#     llm_config = get_llm_config("openai-gpt-4o-mini.json")
+#
+#     # Create test messages
+#     messages = []
+#     for i in range(15):
+#         messages.append(
+#             PydanticMessage(
+#                 role=MessageRole.user,
+#                 content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
+#             )
+#         )
+#         messages.append(
+#             PydanticMessage(
+#                 role=MessageRole.assistant,
+#                 content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
+#             )
+#         )
+#
+#     agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
+#     original_message_count = len(agent_state.message_ids)
+#
+#     # Create an input message to trigger the agent
+#     input_message = MessageCreate(
+#         role=MessageRole.user,
+#         content=[TextContent(type="text", text="Hello, please respond.")],
+#     )
+#
+#     # Create a proper run record in the database
+#     run_manager = RunManager()
+#     test_run_id = f"run-{uuid.uuid4()}"
+#     test_run = Run(
+#         id=test_run_id,
+#         agent_id=agent_state.id,
+#     )
+#     await run_manager.create_run(test_run, actor)
+#
+#     # Create the agent loop using LettaAgentV3
+#     agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
+#
+#     # Track how many times request_async is called
+#     call_count = 0
+#
+#     # Store original request_async method
+#     original_request_async = agent_loop.llm_client.request_async
+#
+#     async def mock_request_async_with_error(request_data, llm_config):
+#         nonlocal call_count
+#         call_count += 1
+#         if call_count == 1:
+#             # First call raises OpenAI BadRequestError with context_length_exceeded error code
+#             # This will be properly converted to ContextWindowExceededError by handle_llm_error
+#             import httpx
+#
+#             # Create a mock response with the required structure
+#             mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
+#             mock_response = httpx.Response(
+#                 status_code=400,
+#                 request=mock_request,
+#                 json={
+#                     "error": {
+#                         "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
+#                         "type": "invalid_request_error",
+#                         "code": "context_length_exceeded",
+#                     }
+#                 },
+#             )
+#
+#             raise openai.BadRequestError(
+#                 message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
+#                 response=mock_response,
+#                 body={
+#                     "error": {
+#                         "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
+#                         "type": "invalid_request_error",
+#                         "code": "context_length_exceeded",
+#                     }
+#                 },
+#             )
+#         # Subsequent calls use the real implementation
+#         return await original_request_async(request_data, llm_config)
+#
+#     # Patch the llm_client's request_async to raise ContextWindowExceededError on first call
+#     with patch.object(agent_loop.llm_client, "request_async", side_effect=mock_request_async_with_error):
+#         # Execute a blocking step
+#         try:
+#             result = await agent_loop.step(
+#                 input_messages=[input_message],
+#                 max_steps=1,
+#                 run_id=test_run_id,
+#             )
+#         except Exception as e:
+#             # Some errors might happen due to real LLM calls after retry
+#             print(f"Exception during step: {e}")
+#
+#     # Reload agent state to get updated message_ids after summarization
+#     updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
+#     updated_message_count = len(updated_agent_state.message_ids)
+#
+#     # Fetch the updated in-context messages
+#     updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
+#         message_ids=updated_agent_state.message_ids, actor=actor
+#     )
+#
+#     # Convert to LettaMessage format for easier content inspection
+#     letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
+#
+#     # Verify a summary message exists with the correct format
+#     # The summary message has content with type="system_alert" and message containing:
+#     # "prior messages ... have been hidden" and "summary of the previous"
+#     import json
+#
+#     summary_message_found = False
+#     summary_message_text = None
+#     for msg in letta_messages:
+#         # Not all message types have a content attribute (e.g., ReasoningMessage)
+#         if not hasattr(msg, "content"):
+#             continue
+#
+#         content = msg.content
+#         # Content can be a string (JSON) or an object with type/message fields
+#         if isinstance(content, str):
+#             # Try to parse as JSON
+#             try:
+#                 parsed = json.loads(content)
+#                 if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
+#                     text_to_check = parsed.get("message", "").lower()
+#                     if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
+#                         summary_message_found = True
+#                         summary_message_text = parsed.get("message")
+#                         break
+#             except (json.JSONDecodeError, TypeError):
+#                 pass
+#         # Check if content has system_alert type with the summary message (object form)
+#         elif hasattr(content, "type") and content.type == "system_alert":
+#             if hasattr(content, "message") and content.message:
+#                 text_to_check = content.message.lower()
+#                 if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
+#                     summary_message_found = True
+#                     summary_message_text = content.message
+#                     break
+#
+#     assert summary_message_found, (
+#         "A summary message should exist in the in-context messages after summarization. "
+#         "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
+#     )
+#
+#     # Verify we attempted multiple invocations (the failing one + retry after summarization)
+#     assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
+#
+#     # The original messages should have been compacted - the updated count should be less than
+#     # original + the new messages added (input + assistant response + tool results)
+#     print("Test passed: Summary message found in context (blocking mode)")
+#     print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
+#     print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
+#     print(f"Total LLM invocations: {call_count}")
+#
+#
+# @pytest.mark.asyncio
+# @pytest.mark.parametrize(
+#     "llm_config",
+#     TESTED_LLM_CONFIGS,
+#     ids=[c.model for c in TESTED_LLM_CONFIGS],
+# )
+# async def test_summarize_all_with_real_llm(server: SyncServer, actor, llm_config: LLMConfig):
+#     """
+#     Test the summarize_all function with real LLM calls.
+#
+#     This test verifies that the 'all' summarization mode works correctly,
+#     summarizing the entire conversation into a single summary string.
+#     """
+#     from letta.schemas.model import ModelSettings
+#     from letta.services.summarizer.summarizer_all import summarize_all
+#     from letta.services.summarizer.summarizer_config import get_default_summarizer_config
+#
+#     # Create a summarizer config with "all" mode
+#     model_settings = ModelSettings()
+#     summarizer_config = get_default_summarizer_config(model_settings)
+#     summarizer_config.mode = "all"
+#
+#     # Create test messages - a simple conversation
+#     messages = [
+#         PydanticMessage(
+#             role=MessageRole.system,
+#             content=[TextContent(type="text", text="You are a helpful assistant.")],
+#         )
+#     ]
+#
+#     # Add 10 user-assistant pairs
+#     for i in range(10):
+#         messages.append(
+#             PydanticMessage(
+#                 role=MessageRole.user,
+#                 content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
+#             )
+#         )
+#         messages.append(
+#             PydanticMessage(
+#                 role=MessageRole.assistant,
+#                 content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
+#             )
+#         )
+#
+#     assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"
+#
+#     # Call summarize_all with real LLM
+#     summary = await summarize_all(
+#         actor=actor,
+#         llm_config=llm_config,
+#         summarizer_config=summarizer_config,
+#         in_context_messages=messages,
+#         new_messages=[],
+#     )
+#
+#     # Verify the summary was generated
+#     assert summary is not None
+#     assert len(summary) > 0
+#
+#     print(f"Successfully summarized {len(messages)} messages using 'all' mode")
+#     print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
+#     print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
+#
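
Note on the arithmetic the regression test guards: the error "No assistant message found from indices 650 to 65" comes from treating a percentage as a whole number (650 = 10 * 65). A minimal, self-contained sketch of the cutoff computation; the variable names mirror the diff's comments but are illustrative, not the summarizer's actual internals:

    # Hypothetical reconstruction of the cutoff-index arithmetic described in the
    # test comments above; not the actual letta summarizer code.
    message_count = 65

    # Buggy form: the percentage floor was written as 10 instead of 0.10, so the
    # starting percent becomes 10 and the cutoff lands at 10 * 65 = 650, past the
    # end of the list, producing "No assistant message found from indices 650 to 65".
    buggy_percent = max(0.7, 10)  # == 10
    assert round(buggy_percent * message_count) == 650

    # Fixed form: percentages stay in [0, 1], so the cutoff stays inside the window.
    fixed_percent = max(0.7, 0.10)  # == 0.7
    assert round(fixed_percent * message_count) == 46  # valid index into 65 messages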