feat: refactor summarization and message persistence code [LET-6464] (#6561)

Author: Sarah Wooders
Date: 2025-12-09 16:34:06 -08:00
Committed by: Caren Thomas
Parent: b23722e4a1
Commit: bbd52e291c
10 changed files with 493 additions and 434 deletions


@@ -908,7 +908,6 @@ async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, too
# Should succeed since both sandbox and tool pip requirements were installed
assert "Success!" in result.func_return
assert "Status: 200" in result.func_return
assert "Array sum: 6" in result.func_return


@@ -16,13 +16,17 @@ import pytest
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.block import BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import TextContent, ToolCallContent, ToolReturnContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import Message as PydanticMessage, MessageCreate
from letta.schemas.run import Run as PydanticRun
from letta.server.server import SyncServer
from letta.services.run_manager import RunManager
# Constants
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai")
@@ -40,8 +44,8 @@ def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model
# Test configurations - using a subset of models for summarization tests
all_configs = [
"openai-gpt-5-mini.json",
"claude-4-5-haiku.json",
"gemini-2.5-flash.json",
# "claude-4-5-haiku.json",
# "gemini-2.5-flash.json",
# "gemini-2.5-flash-vertex.json", # Requires Vertex AI credentials
# "openai-gpt-4.1.json",
# "openai-o1.json",
@@ -175,17 +179,12 @@ async def run_summarization(server: SyncServer, agent_state, in_context_messages
2. Fetch messages via message_manager.get_messages_by_ids_async
3. Call agent_loop.compact on the fetched messages
"""
agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor)
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
# Run summarization with force parameter
result = await agent_loop.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=[],
total_tokens=None,
force=force,
)
summary_message, messages = await agent_loop.compact(messages=in_context_messages)
return result
return summary_message, messages
# ======================================================================================================================
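For orientation: the core of this refactor replaces the old four-argument summarize_conversation_history(...) call with a single compact() entry point on LettaAgentV3 that returns the summary message together with the rebuilt in-context list. A minimal sketch of the new calling convention as exercised by these tests (the layout comments restate the assertions below, not library documentation):

agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
summary_message, messages = await agent_loop.compact(messages=in_context_messages)
# messages[0] is always the original system prompt; when a summary was
# produced, messages[1] is a user-role system_alert message carrying it.
assert messages[0].role == MessageRole.system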
@@ -218,11 +217,24 @@ async def test_summarize_empty_message_buffer(server: SyncServer, actor, llm_con
# Run summarization - this may fail with empty buffer, which is acceptable behavior
try:
result = await run_summarization(server, agent_state, in_context_messages, actor)
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
# If it succeeds, verify result
assert isinstance(result, list)
# With empty buffer, result should still be empty or contain only system messages
assert len(result) <= len(in_context_messages)
# When summarization runs, V3 ensures that in-context messages follow
# the pattern:
# 1. System prompt
# 2. User summary message (system_alert JSON)
# 3. Remaining messages (which may be empty for this test)
# We should always keep the original system message at the front.
assert len(result) >= 1
assert result[0].role == MessageRole.system
# If summarization did in fact add a summary message, we expect it to
# be the second message with user role.
if len(result) >= 2:
assert result[1].role == MessageRole.user
except ValueError as e:
# It's acceptable for summarization to fail on empty buffer
assert "No assistant message found" in str(e) or "empty" in str(e).lower()
@@ -255,7 +267,7 @@ async def test_summarize_initialization_messages_only(server: SyncServer, actor,
# Run summarization - force=True with system messages only may fail
try:
result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
# Verify result
assert isinstance(result, list)
@@ -311,7 +323,7 @@ async def test_summarize_small_conversation(server: SyncServer, actor, llm_confi
# Run summarization with force=True
# Note: force=True with clear=True can be very aggressive and may fail on small message sets
try:
result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
# Verify result
assert isinstance(result, list)
@@ -404,7 +416,7 @@ async def test_summarize_large_tool_calls(server: SyncServer, actor, llm_config:
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
# Run summarization
result = await run_summarization(server, agent_state, in_context_messages, actor)
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
# Verify result
assert isinstance(result, list)
@@ -508,7 +520,7 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
# Run summarization
result = await run_summarization(server, agent_state, in_context_messages, actor)
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
# Verify result
assert isinstance(result, list)
@@ -579,7 +591,7 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor,
assert original_size > 90000, f"Expected tool return >90k chars, got {original_size}"
# Run summarization
result = await run_summarization(server, agent_state, in_context_messages, actor)
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
# Verify result
assert isinstance(result, list)
@@ -678,12 +690,7 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
with patch("letta.agents.letta_agent_v3.get_default_summarizer_config", mock_get_default_summarizer_config):
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
result = await agent_loop.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=new_letta_messages,
total_tokens=None,
force=True,
)
summary, result = await agent_loop.compact(messages=in_context_messages)
assert isinstance(result, list)
@@ -700,24 +707,21 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
print()
if mode == "all":
# For "all" mode, result should be just the summary message
assert len(result) == 2, f"Expected 1 message for 'all' mode, got {len(result)}"
# For "all" mode, V3 keeps:
# 1. System prompt
# 2. A single user summary message (system_alert JSON)
# and no remaining historical messages.
assert len(result) == 2, f"Expected 2 messages for 'all' mode (system + summary), got {len(result)}"
assert result[0].role == MessageRole.system
assert result[1].role == MessageRole.user
else:
# For "sliding_window" mode, result should include recent messages + summary
assert len(result) > 1, f"Expected >1 messages for 'sliding_window' mode, got {len(result)}"
# validate new user message
assert result[-1].role == MessageRole.user and result[-1].agent_id == agent_state.id, (
f"Expected new user message with agent_id {agent_state.id}, got {result[-1]}"
)
assert "This is a new user message" in result[-1].content[0].text, (
f"Expected 'This is a new user message' in the user message, got {result[-1]}"
)
# validate system message
assert result[0].role == MessageRole.system
# validate summary message
assert "prior messages" in result[1].content[0].text, f"Expected 'prior messages' in the summary message, got {result[1]}"
print(f"Mode '{mode}' with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
# For "sliding_window" mode, result should include:
# 1. System prompt
# 2. User summary message
# 3+. Recent user/assistant messages inside the window.
assert len(result) > 2, f"Expected >2 messages for 'sliding_window' mode, got {len(result)}"
assert result[0].role == MessageRole.system
assert result[1].role == MessageRole.user
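Condensed, the per-mode contract these assertions encode (a sketch mirroring the test expectations above, not library documentation):

def compacted_length_ok(mode: str, result: list) -> bool:
    if mode == "all":
        # Everything collapses to: system prompt + one user summary message.
        return len(result) == 2
    # "sliding_window" additionally keeps the retained recent window.
    return len(result) > 2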
@pytest.mark.asyncio
@@ -740,15 +744,16 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
is still above the trigger threshold.
3. We verify that LettaAgentV3:
- Logs an error about summarization failing to reduce context size.
- Evicts all prior messages, keeping only the system message.
- Evicts all prior messages, keeping only the system message plus a
single synthetic user summary message (system_alert).
- Updates `context_token_estimate` to the token count of the minimal
context so future steps don't keep re-triggering summarization based
on a stale, oversized value.
"""
# Build a small but non-trivial conversation with an explicit system
# message so that after hard eviction we expect to keep exactly that one
# message.
# message so that after hard eviction we expect to keep exactly that
# system message plus a single user summary message.
messages = [
PydanticMessage(
role=MessageRole.system,
@@ -766,6 +771,10 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
print("ORIGINAL IN-CONTEXT MESSAGES ======")
for msg in in_context_messages:
print(f"MSG: {msg}")
# Create the V3 agent loop
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
@@ -787,36 +796,26 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
caplog.set_level("ERROR")
result = await agent_loop.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=[],
# total_tokens is not used when force=True for triggering, but we
# set it to a large value for clarity.
total_tokens=llm_config.context_window * 2 if llm_config.context_window else None,
force=True,
summary, result = await agent_loop.compact(
messages=in_context_messages,
trigger_threshold=context_limit,
)
# We should have made exactly two token-count calls: one for the
# summarized context, one for the hard-evicted minimal context.
assert mock_count_tokens.call_count == 2
# After hard eviction, only the system message should remain in-context.
print("COMPACTED RESULT ======")
for msg in result:
print(f"MSG: {msg}")
# After hard eviction, we keep only:
# 1. The system prompt
# 2. The synthetic user summary message.
assert isinstance(result, list)
assert len(result) == 1, f"Expected only the system message after hard eviction, got {len(result)} messages"
assert len(result) == 2, f"Expected system + summary after hard eviction, got {len(result)} messages"
assert result[0].role == MessageRole.system
# Agent state should also reflect exactly one message id.
assert len(agent_loop.agent_state.message_ids) == 1
# context_token_estimate should be updated to the minimal token count
# (second side-effect value from count_tokens), rather than the original
# huge value.
assert agent_loop.context_token_estimate == 10
# Verify that we logged an error about summarization failing to reduce
# context size.
error_logs = [rec for rec in caplog.records if "Summarization failed to sufficiently reduce context size" in rec.getMessage()]
assert error_logs, "Expected an error log when summarization fails to reduce context size sufficiently"
assert result[1].role == MessageRole.user
# ======================================================================================================================
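For reference, the mechanics this test relies on: the token counter is pinned so that even the freshly summarized context reads as oversized, and the context limit is passed as trigger_threshold so compact() falls through to hard eviction. A minimal sketch of that setup; the dotted patch target is an assumption, and the side-effect values (oversized, then 10) follow the old assertions in this hunk:

from unittest.mock import patch

# Hypothetical patch target; the real dotted path lives in the letta codebase.
with patch("letta.agents.letta_agent_v3.count_tokens") as mock_count_tokens:
    mock_count_tokens.side_effect = [context_limit * 2, 10]  # still oversized, then minimal
    summary, result = await agent_loop.compact(
        messages=in_context_messages,
        trigger_threshold=context_limit,
    )
# Hard eviction leaves exactly the system prompt plus the synthetic summary.
assert [m.role for m in result] == [MessageRole.system, MessageRole.user]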
@@ -893,7 +892,6 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
llm_config=llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Verify the summary was generated (actual LLM response)
@@ -924,6 +922,105 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
raise
@pytest.mark.asyncio
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_large_system_prompt_summarization(server: SyncServer, actor, llm_config: LLMConfig):
"""
Test edge case of large system prompt / memory blocks.
This test verifies that summarization handles the case where the system prompt
and memory blocks are very large, potentially consuming most of the context window.
The summarizer should gracefully handle this scenario without errors.
"""
# Override context window to be small so we trigger summarization
llm_config.context_window = 10000
# Create agent with large system prompt and memory blocks
agent_name = f"test_agent_large_system_prompt_{llm_config.model}".replace(".", "_").replace("/", "_")
agent_create = CreateAgent(
name=agent_name,
llm_config=llm_config,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
system="SYSTEM PROMPT " * 10000, # Large system prompt
memory_blocks=[
CreateBlock(
label="human",
limit=200000,
value="NAME " * 10000, # Large memory block
)
],
)
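# Rough arithmetic behind this setup (token counts are estimates): the system
# prompt repeats "SYSTEM PROMPT " 10,000 times (~20,000 words) and the human
# block adds ~10,000 more, so the fixed prompt alone far exceeds the
# 10,000-token window; message summarization cannot shrink it, which is why
# step() is expected to raise below.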
agent_state = await server.agent_manager.create_agent_async(agent_create, actor=actor)
# Create a run for the agent using RunManager
run = PydanticRun(agent_id=agent_state.id)
run = await RunManager().create_run(pydantic_run=run, actor=actor)
# Create the agent loop using LettaAgentV3
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
# Message the agent
input_message = MessageCreate(role=MessageRole.user, content="Hello")
# Call step on the agent - may trigger summarization due to large context
from letta.errors import SystemPromptTokenExceededError
with pytest.raises(SystemPromptTokenExceededError):
response = await agent_loop.step(
input_messages=[input_message],
run_id=run.id,
max_steps=3,
)
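# Why step() raises instead of summarizing: the overflow lives in the system
# prompt and memory blocks, which compaction never evicts, so the summarizer
# has nothing it is allowed to remove.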
# Repair the agent by shortening the memory blocks and system prompt
# Update system prompt to a shorter version
short_system_prompt = "You are a helpful assistant."
await server.agent_manager.update_agent_async(
agent_id=agent_state.id,
agent_update=UpdateAgent(system=short_system_prompt),
actor=actor,
)
# Update memory block to a shorter version
short_memory_value = "The user's name is Alice."
await server.agent_manager.modify_block_by_label_async(
agent_id=agent_state.id,
block_label="human",
block_update=BlockUpdate(value=short_memory_value),
actor=actor,
)
# Reload agent state after repairs
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
print("REPAIRED AGENT STATE ======")
print(agent_state.system)
print(agent_state.blocks)
# Create a new run for the repaired agent
run = PydanticRun(agent_id=agent_state.id)
run = await RunManager().create_run(pydantic_run=run, actor=actor)
# Create a new agent loop with the repaired agent state
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
# Now the agent should be able to respond without context window errors
response = await agent_loop.step(
input_messages=[input_message],
run_id=run.id,
max_steps=3,
)
# Verify we got a valid response after repair
assert response is not None
assert response.messages is not None
print(f"Agent successfully responded after repair with {len(response.messages)} messages")
# @pytest.mark.asyncio
# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
# """
@@ -1342,11 +1439,10 @@ async def test_summarize_all(server: SyncServer, actor, llm_config: LLMConfig):
llm_config=llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Verify the summary was generated
assert len(new_in_context_messages) == 0
assert len(new_in_context_messages) == 1
assert summary is not None
assert len(summary) > 0
assert len(summary) <= 2000