Add modes self and self_sliding_window for prompt caching (#9372)

* add self compaction method with proper caching (pass in tools, don't refresh sys prompt beforehand) + sliding fallback

* updated prompts for self compaction

* add tests for the self and self_sliding_window modes, and for skipping the message refresh before compaction

* add cache logging to summarization

* better handling to prevent the agent from continuing the conversation in self modes

* if mode changes via summarize endpoint, will use default prompt for the new mode

---------

Co-authored-by: Amy Guan <amy@letta.com>
This commit is contained in:
amysguan
2026-02-24 10:15:36 -08:00
committed by Caren Thomas
parent 47d55362a4
commit 47b0c87ebe
15 changed files with 1065 additions and 223 deletions

View File

@@ -15,17 +15,14 @@ import pytest
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.block import BlockUpdate, CreateBlock
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import EventMessage, SummaryMessage
from letta.schemas.letta_message_content import TextContent, ToolCallContent, ToolReturnContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage, MessageCreate
from letta.schemas.run import Run as PydanticRun
from letta.schemas.message import Message as PydanticMessage
from letta.server.server import SyncServer
from letta.services.run_manager import RunManager
from letta.services.summarizer.summarizer import simple_summary
from letta.settings import model_settings
@@ -669,14 +666,24 @@ from unittest.mock import patch
from letta.services.summarizer.summarizer_config import CompactionSettings
# Test both summarizer modes: "all" summarizes entire history, "sliding_window" keeps recent messages
SUMMARIZER_CONFIG_MODES: list[Literal["all", "sliding_window"]] = ["all", "sliding_window"]
# Test all summarizer modes: "all" summarizes entire history, "sliding_window" keeps recent messages
SUMMARIZER_CONFIG_MODES: list[Literal["all", "sliding_window", "self_compact_all", "self_compact_sliding_window"]] = [
"all",
"sliding_window",
"self_compact_all",
"self_compact_sliding_window",
]
@pytest.mark.asyncio
@pytest.mark.parametrize("mode", SUMMARIZER_CONFIG_MODES, ids=SUMMARIZER_CONFIG_MODES)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig, mode: Literal["all", "sliding_window"]):
async def test_summarize_with_mode(
server: SyncServer,
actor,
llm_config: LLMConfig,
mode: Literal["all", "sliding_window", "self_compact_all", "self_compact_sliding_window"],
):
"""
Test summarization with different CompactionSettings modes using LettaAgentV3.
@@ -746,20 +753,20 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
print()
if mode == "all":
# For "all" mode, V3 keeps:
if mode == "all" or mode == "self_compact_all":
# For "all" or "self" mode, V3 keeps:
# 1. System prompt
# 2. A single user summary message (system_alert JSON)
# and no remaining historical messages.
assert len(result) == 2, f"Expected 2 messages for 'all' mode (system + summary), got {len(result)}"
assert len(result) == 2, f"Expected 2 messages for {mode} mode (system + summary), got {len(result)}"
assert result[0].role == MessageRole.system
assert result[1].role == MessageRole.user
else:
# For "sliding_window" mode, result should include:
# For "sliding_window" or "self_compact_sliding_window" mode, result should include:
# 1. System prompt
# 2. User summary message
# 3+. Recent user/assistant messages inside the window.
assert len(result) > 2, f"Expected >2 messages for 'sliding_window' mode, got {len(result)}"
assert len(result) > 2, f"Expected >2 messages for {mode} mode, got {len(result)}"
assert result[0].role == MessageRole.system
assert result[1].role == MessageRole.user
@@ -1195,97 +1202,206 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_large_system_prompt_summarization(server: SyncServer, actor, llm_config: LLMConfig):
async def test_self_sliding_window_cutoff_index_does_not_exceed_message_count(server: SyncServer, actor, llm_config: LLMConfig):
"""
Test edge case of large system prompt / memory blocks.
Test that the sliding window summarizer correctly calculates cutoff indices.
This test verifies that summarization handles the case where the system prompt
and memory blocks are very large, potentially consuming most of the context window.
The summarizer should gracefully handle this scenario without errors.
This test verifies the fix for a bug where the cutoff percentage was treated as
a whole number (10) instead of a decimal (0.10), causing:
message_cutoff_index = round(10 * 65) = 650
when there were only 65 messages, resulting in an empty range loop and the error:
"No assistant message found from indices 650 to 65"
The fix changed:
- max(..., 10) -> max(..., 0.10)
- += 10 -> += 0.10
- >= 100 -> >= 1.0
This test uses the real token counter (via create_token_counter) to verify
the sliding window logic works with actual token counting.
"""
from letta.llm_api.llm_client import LLMClient
from letta.schemas.agent import AgentType
from letta.services.summarizer.self_summarizer import self_summarize_sliding_window
from letta.services.summarizer.summarizer_config import CompactionSettings
from letta.services.telemetry_manager import TelemetryManager
# Override context window to be small so we trigger summarization
llm_config.context_window = 10000
# Create a real summarizer config using the default factory
# Override sliding_window_percentage to 0.3 for this test
handle = llm_config.handle or f"{llm_config.model_endpoint_type}/{llm_config.model}"
summarizer_config = CompactionSettings(model=handle)
summarizer_config.sliding_window_percentage = 0.3
# Create agent with large system prompt and memory blocks
agent_name = f"test_agent_large_system_prompt_{llm_config.model}".replace(".", "_").replace("/", "_")
agent_create = CreateAgent(
name=agent_name,
llm_config=llm_config,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
system="SYSTEM PROMPT " * 10000, # Large system prompt
memory_blocks=[
CreateBlock(
label="human",
limit=200000,
value="NAME " * 10000, # Large memory block
# Create 65 messages (similar to the failing case in the bug report)
# Pattern: system + alternating user/assistant messages
messages = [
PydanticMessage(
role=MessageRole.system,
content=[TextContent(type="text", text="You are a helpful assistant.")],
)
]
# Add 64 more messages (32 user-assistant pairs)
for i in range(32):
messages.append(
PydanticMessage(
role=MessageRole.user,
content=[TextContent(type="text", text=f"User message {i}")],
)
)
messages.append(
PydanticMessage(
role=MessageRole.assistant,
content=[TextContent(type="text", text=f"Assistant response {i}")],
)
],
)
agent_state = await server.agent_manager.create_agent_async(agent_create, actor=actor)
# Create a run for the agent using RunManager
run = PydanticRun(agent_id=agent_state.id)
run = await RunManager().create_run(pydantic_run=run, actor=actor)
# Create the agent loop using LettaAgentV3
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
# message the agent
input_message = MessageCreate(role=MessageRole.user, content="Hello")
# Call step on the agent - may trigger summarization due to large context
from letta.errors import SystemPromptTokenExceededError
with pytest.raises(SystemPromptTokenExceededError):
response = await agent_loop.step(
input_messages=[input_message],
run_id=run.id,
max_steps=3,
)
# Repair the agent by shortening the memory blocks and system prompt
# Update system prompt to a shorter version
short_system_prompt = "You are a helpful assistant."
await server.agent_manager.update_agent_async(
agent_id=agent_state.id,
agent_update=UpdateAgent(system=short_system_prompt),
actor=actor,
)
assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
# Update memory block to a shorter version
short_memory_value = "The user's name is Alice."
await server.agent_manager.modify_block_by_label_async(
agent_id=agent_state.id,
block_label="human",
block_update=BlockUpdate(value=short_memory_value),
actor=actor,
)
# This should NOT raise "No assistant message found from indices 650 to 65"
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
try:
summary, remaining_messages = await self_summarize_sliding_window(
actor=actor,
agent_id="agent-test-self-sliding-window",
agent_llm_config=llm_config,
telemetry_manager=TelemetryManager(),
llm_client=LLMClient.create(llm_config),
agent_type=AgentType.letta_v1_agent,
messages=messages,
compaction_settings=summarizer_config,
timezone="UTC",
)
# Reload agent state after repairs
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
print("REPAIRED AGENT STATE ======")
print(agent_state.system)
print(agent_state.blocks)
# Verify the summary was generated (actual LLM response)
assert summary is not None
assert len(summary) > 0
# Create a new run for the repaired agent
run = PydanticRun(agent_id=agent_state.id)
run = await RunManager().create_run(pydantic_run=run, actor=actor)
# Verify remaining messages is a valid subset
assert len(remaining_messages) < len(messages)
assert len(remaining_messages) > 0
# Create a new agent loop with the repaired agent state
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
print(f"Using {llm_config.model_endpoint_type} token counter for model {llm_config.model}")
# Now the agent should be able to respond without context window errors
response = await agent_loop.step(
input_messages=[input_message],
run_id=run.id,
max_steps=3,
)
except ValueError as e:
if "No assistant message found from indices" in str(e):
# Extract the indices from the error message
import re
# Verify we got a valid response after repair
assert response is not None
assert response.messages is not None
print(f"Agent successfully responded after repair with {len(response.messages)} messages")
match = re.search(r"from indices (\d+) to (\d+)", str(e))
if match:
start_idx, end_idx = int(match.group(1)), int(match.group(2))
pytest.fail(
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
f"Error: {e}"
)
raise
### NOTE: removing edge case test where sys prompt is huge for now
### because we no longer refresh the system prompt before compaction
### in order to leverage caching (for self compaction)
# @pytest.mark.asyncio
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# async def test_large_system_prompt_summarization(server: SyncServer, actor, llm_config: LLMConfig):
# """
# Test edge case of large system prompt / memory blocks.
# This test verifies that summarization handles the case where the system prompt
# and memory blocks are very large, potentially consuming most of the context window.
# The summarizer should gracefully handle this scenario without errors.
# """
# # Override context window to be small so we trigger summarization
# llm_config.context_window = 10000
# # Create agent with large system prompt and memory blocks
# agent_name = f"test_agent_large_system_prompt_{llm_config.model}".replace(".", "_").replace("/", "_")
# agent_create = CreateAgent(
# name=agent_name,
# llm_config=llm_config,
# embedding_config=DEFAULT_EMBEDDING_CONFIG,
# system="SYSTEM PROMPT " * 10000, # Large system prompt
# memory_blocks=[
# CreateBlock(
# label="human",
# limit=200000,
# value="NAME " * 10000, # Large memory block
# )
# ],
# )
# agent_state = await server.agent_manager.create_agent_async(agent_create, actor=actor)
# # Create a run for the agent using RunManager
# run = PydanticRun(agent_id=agent_state.id)
# run = await RunManager().create_run(pydantic_run=run, actor=actor)
# # Create the agent loop using LettaAgentV3
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
# # message the agent
# input_message = MessageCreate(role=MessageRole.user, content="Hello")
# # Call step on the agent - may trigger summarization due to large context
# from letta.errors import SystemPromptTokenExceededError
# with pytest.raises(SystemPromptTokenExceededError):
# response = await agent_loop.step(
# input_messages=[input_message],
# run_id=run.id,
# max_steps=3,
# )
# # Repair the agent by shortening the memory blocks and system prompt
# # Update system prompt to a shorter version
# short_system_prompt = "You are a helpful assistant."
# await server.agent_manager.update_agent_async(
# agent_id=agent_state.id,
# agent_update=UpdateAgent(system=short_system_prompt),
# actor=actor,
# )
# # Update memory block to a shorter version
# short_memory_value = "The user's name is Alice."
# await server.agent_manager.modify_block_by_label_async(
# agent_id=agent_state.id,
# block_label="human",
# block_update=BlockUpdate(value=short_memory_value),
# actor=actor,
# )
# # Reload agent state after repairs
# agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
# print("REPAIRED AGENT STATE ======")
# print(agent_state.system)
# print(agent_state.blocks)
# # Create a new run for the repaired agent
# run = PydanticRun(agent_id=agent_state.id)
# run = await RunManager().create_run(pydantic_run=run, actor=actor)
# # Create a new agent loop with the repaired agent state
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
# # Now the agent should be able to respond without context window errors
# response = await agent_loop.step(
# input_messages=[input_message],
# run_id=run.id,
# max_steps=3,
# )
# # Verify we got a valid response after repair
# assert response is not None
# assert response.messages is not None
# print(f"Agent successfully responded after repair with {len(response.messages)} messages")
# @pytest.mark.asyncio
@@ -1718,6 +1834,127 @@ async def test_summarize_all(server: SyncServer, actor, llm_config: LLMConfig):
print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_summarize_self(server: SyncServer, actor, llm_config: LLMConfig):
    """
    Test the self_summarize_all function with real LLM calls.

    This test verifies that the "self" summarization mode works correctly,
    summarizing the entire conversation into a single summary string and a
    single in-context summary message.
    """
    from letta.llm_api.llm_client import LLMClient
    from letta.schemas.agent import AgentType
    from letta.services.summarizer.self_summarizer import self_summarize_all
    from letta.services.summarizer.summarizer_config import CompactionSettings
    from letta.services.telemetry_manager import TelemetryManager

    # Create a summarizer config with "self" mode
    handle = llm_config.handle or f"{llm_config.model_endpoint_type}/{llm_config.model}"
    summarizer_config = CompactionSettings(model=handle)
    summarizer_config.mode = "self"

    # Create test messages - a simple conversation: one system message
    # followed by 10 user/assistant pairs (21 messages total).
    messages = [
        PydanticMessage(
            role=MessageRole.system,
            content=[TextContent(type="text", text="You are a helpful assistant.")],
        )
    ]
    # Add 10 user-assistant pairs
    for i in range(10):
        messages.append(
            PydanticMessage(
                role=MessageRole.user,
                content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
            )
        )
        messages.append(
            PydanticMessage(
                role=MessageRole.assistant,
                content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
            )
        )
    assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"

    # Call self_summarize_all with a real LLM.
    # NOTE(review): agent_id reuses "agent-test-self-sliding-window" from the
    # sliding-window test; it appears to be an arbitrary opaque id — confirm.
    summary, new_in_context_messages = await self_summarize_all(
        actor=actor,
        agent_id="agent-test-self-sliding-window",
        agent_llm_config=llm_config,
        telemetry_manager=TelemetryManager(),
        llm_client=LLMClient.create(llm_config),
        agent_type=AgentType.letta_v1_agent,
        messages=messages,
        compaction_settings=summarizer_config,
        timezone="UTC",
    )

    # Verify the summary was generated: the whole history collapses to a
    # single in-context message, and the summary text is non-empty.
    assert len(new_in_context_messages) == 1
    assert summary is not None
    assert len(summary) > 0
    # Summary should be under ~500 words; 5000 characters gives a generous buffer.
    assert len(summary) <= 5000

    print(f"Successfully summarized {len(messages)} messages using 'self' mode")
    print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
    print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
@pytest.mark.asyncio
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
async def test_self_mode_fallback(server: SyncServer, actor, llm_config: LLMConfig):
    """If self summarize fails, it should have proper fallback."""
    from unittest.mock import AsyncMock, patch

    # Build a small conversation: one system message plus 10 user/assistant pairs.
    messages = [
        PydanticMessage(
            role=MessageRole.system,
            content=[TextContent(type="text", text="You are a helpful assistant.")],
        )
    ]
    for i in range(10):
        user_msg = PydanticMessage(
            role=MessageRole.user,
            content=[TextContent(type="text", text=f"User message {i}: Test message {i}.")],
        )
        assistant_msg = PydanticMessage(
            role=MessageRole.assistant,
            content=[TextContent(type="text", text=f"Assistant response {i}: Acknowledged message {i}.")],
        )
        messages.extend((user_msg, assistant_msg))

    # Set up an agent configured for self-compaction ("self_compact_all" mode).
    agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
    handle = llm_config.handle or f"{llm_config.model_endpoint_type}/{llm_config.model}"
    agent_state.compaction_settings = CompactionSettings(model=handle, mode="self_compact_all")
    agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)

    # Force self_summarize_all to raise so the fallback path is exercised.
    with patch(
        "letta.services.summarizer.compact.self_summarize_all",
        new_callable=AsyncMock,
        side_effect=RuntimeError("Simulated self_summarize_all failure"),
    ):
        summary_message, compacted_messages, summary_text = await agent_loop.compact(messages=in_context_messages)

    # The fallback must still produce a summary and shrink the context.
    assert summary_message is not None
    assert summary_text is not None
    assert len(summary_text) > 0
    assert len(compacted_messages) < len(in_context_messages)
    print(f"Fallback succeeded: {len(in_context_messages)} -> {len(compacted_messages)} messages")
# =============================================================================
# CompactionStats tests
# =============================================================================
@@ -2033,3 +2270,15 @@ async def test_compact_with_stats_params_embeds_stats(server: SyncServer, actor,
assert stats.context_tokens_after is not None # Should be set by compact()
assert stats.messages_count_after == len(compacted_messages) # final_messages already includes summary
assert stats.context_window == llm_config.context_window
### basic self summarization
### fallback chain
### basic self sliding window summarization
### self sliding window preserves recent msgs
### self mode return compaction stats

View File

@@ -209,7 +209,7 @@ class TestSummarizeSlidingWindowTelemetryContext:
await summarizer_sliding_window.summarize_via_sliding_window(
actor=mock_actor,
llm_config=mock_llm_config,
agent_llm_config=mock_llm_config, # case where agent and summarizer have same config
agent_llm_config=mock_llm_config, # case where agent and summarizer have same config
summarizer_config=mock_compaction_settings,
in_context_messages=mock_messages,
agent_id=agent_id,