feat: refactor summarization and message persistence code [LET-6464] (#6561)
This commit is contained in:
committed by
Caren Thomas
parent
b23722e4a1
commit
bbd52e291c
@@ -908,7 +908,6 @@ async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, too
|
||||
|
||||
# Should succeed since both sandbox and tool pip requirements were installed
|
||||
assert "Success!" in result.func_return
|
||||
assert "Status: 200" in result.func_return
|
||||
assert "Array sum: 6" in result.func_return
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,17 @@ import pytest
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import TextContent, ToolCallContent, ToolReturnContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageCreate
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.run_manager import RunManager
|
||||
|
||||
# Constants
|
||||
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai")
|
||||
@@ -40,8 +44,8 @@ def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model
|
||||
# Test configurations - using a subset of models for summarization tests
|
||||
all_configs = [
|
||||
"openai-gpt-5-mini.json",
|
||||
"claude-4-5-haiku.json",
|
||||
"gemini-2.5-flash.json",
|
||||
# "claude-4-5-haiku.json",
|
||||
# "gemini-2.5-flash.json",
|
||||
# "gemini-2.5-flash-vertex.json", # Requires Vertex AI credentials
|
||||
# "openai-gpt-4.1.json",
|
||||
# "openai-o1.json",
|
||||
@@ -175,17 +179,12 @@ async def run_summarization(server: SyncServer, agent_state, in_context_messages
|
||||
2. Fetch messages via message_manager.get_messages_by_ids_async
|
||||
3. Call agent_loop.summarize_conversation_history with force=True
|
||||
"""
|
||||
agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor)
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
# Run summarization with force parameter
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=[],
|
||||
total_tokens=None,
|
||||
force=force,
|
||||
)
|
||||
summary_message, messages = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
return result
|
||||
return summary_message, messages
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -218,11 +217,24 @@ async def test_summarize_empty_message_buffer(server: SyncServer, actor, llm_con
|
||||
|
||||
# Run summarization - this may fail with empty buffer, which is acceptable behavior
|
||||
try:
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
# If it succeeds, verify result
|
||||
assert isinstance(result, list)
|
||||
# With empty buffer, result should still be empty or contain only system messages
|
||||
assert len(result) <= len(in_context_messages)
|
||||
|
||||
# When summarization runs, V3 ensures that in-context messages follow
|
||||
# the pattern:
|
||||
# 1. System prompt
|
||||
# 2. User summary message (system_alert JSON)
|
||||
# 3. Remaining messages (which may be empty for this test)
|
||||
|
||||
# We should always keep the original system message at the front.
|
||||
assert len(result) >= 1
|
||||
assert result[0].role == MessageRole.system
|
||||
|
||||
# If summarization did in fact add a summary message, we expect it to
|
||||
# be the second message with user role.
|
||||
if len(result) >= 2:
|
||||
assert result[1].role == MessageRole.user
|
||||
except ValueError as e:
|
||||
# It's acceptable for summarization to fail on empty buffer
|
||||
assert "No assistant message found" in str(e) or "empty" in str(e).lower()
|
||||
@@ -255,7 +267,7 @@ async def test_summarize_initialization_messages_only(server: SyncServer, actor,
|
||||
|
||||
# Run summarization - force=True with system messages only may fail
|
||||
try:
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -311,7 +323,7 @@ async def test_summarize_small_conversation(server: SyncServer, actor, llm_confi
|
||||
# Run summarization with force=True
|
||||
# Note: force=True with clear=True can be very aggressive and may fail on small message sets
|
||||
try:
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -404,7 +416,7 @@ async def test_summarize_large_tool_calls(server: SyncServer, actor, llm_config:
|
||||
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
|
||||
|
||||
# Run summarization
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -508,7 +520,7 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll
|
||||
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
|
||||
|
||||
# Run summarization
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -579,7 +591,7 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor,
|
||||
assert original_size > 90000, f"Expected tool return >90k chars, got {original_size}"
|
||||
|
||||
# Run summarization
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -678,12 +690,7 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
with patch("letta.agents.letta_agent_v3.get_default_summarizer_config", mock_get_default_summarizer_config):
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
total_tokens=None,
|
||||
force=True,
|
||||
)
|
||||
summary, result = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
@@ -700,24 +707,21 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
print()
|
||||
|
||||
if mode == "all":
|
||||
# For "all" mode, result should be just the summary message
|
||||
assert len(result) == 2, f"Expected 1 message for 'all' mode, got {len(result)}"
|
||||
# For "all" mode, V3 keeps:
|
||||
# 1. System prompt
|
||||
# 2. A single user summary message (system_alert JSON)
|
||||
# and no remaining historical messages.
|
||||
assert len(result) == 2, f"Expected 2 messages for 'all' mode (system + summary), got {len(result)}"
|
||||
assert result[0].role == MessageRole.system
|
||||
assert result[1].role == MessageRole.user
|
||||
else:
|
||||
# For "sliding_window" mode, result should include recent messages + summary
|
||||
assert len(result) > 1, f"Expected >1 messages for 'sliding_window' mode, got {len(result)}"
|
||||
# validate new user message
|
||||
assert result[-1].role == MessageRole.user and result[-1].agent_id == agent_state.id, (
|
||||
f"Expected new user message with agent_id {agent_state.id}, got {result[-1]}"
|
||||
)
|
||||
assert "This is a new user message" in result[-1].content[0].text, (
|
||||
f"Expected 'This is a new user message' in the user message, got {result[-1]}"
|
||||
)
|
||||
|
||||
# validate system message
|
||||
assert result[0].role == MessageRole.system
|
||||
# validate summary message
|
||||
assert "prior messages" in result[1].content[0].text, f"Expected 'prior messages' in the summary message, got {result[1]}"
|
||||
print(f"Mode '{mode}' with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
|
||||
# For "sliding_window" mode, result should include:
|
||||
# 1. System prompt
|
||||
# 2. User summary message
|
||||
# 3+. Recent user/assistant messages inside the window.
|
||||
assert len(result) > 2, f"Expected >2 messages for 'sliding_window' mode, got {len(result)}"
|
||||
assert result[0].role == MessageRole.system
|
||||
assert result[1].role == MessageRole.user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -740,15 +744,16 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
is still above the trigger threshold.
|
||||
3. We verify that LettaAgentV3:
|
||||
- Logs an error about summarization failing to reduce context size.
|
||||
- Evicts all prior messages, keeping only the system message.
|
||||
- Evicts all prior messages, keeping only the system message plus a
|
||||
single synthetic user summary message (system_alert).
|
||||
- Updates `context_token_estimate` to the token count of the minimal
|
||||
context so future steps don't keep re-triggering summarization based
|
||||
on a stale, oversized value.
|
||||
"""
|
||||
|
||||
# Build a small but non-trivial conversation with an explicit system
|
||||
# message so that after hard eviction we expect to keep exactly that one
|
||||
# message.
|
||||
# message so that after hard eviction we expect to keep exactly that
|
||||
# system message plus a single user summary message.
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.system,
|
||||
@@ -766,6 +771,10 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
|
||||
agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
|
||||
|
||||
print("ORIGINAL IN-CONTEXT MESSAGES ======")
|
||||
for msg in in_context_messages:
|
||||
print(f"MSG: {msg}")
|
||||
|
||||
# Create the V3 agent loop
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
@@ -787,36 +796,26 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
|
||||
caplog.set_level("ERROR")
|
||||
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=[],
|
||||
# total_tokens is not used when force=True for triggering, but we
|
||||
# set it to a large value for clarity.
|
||||
total_tokens=llm_config.context_window * 2 if llm_config.context_window else None,
|
||||
force=True,
|
||||
summary, result = await agent_loop.compact(
|
||||
messages=in_context_messages,
|
||||
trigger_threshold=context_limit,
|
||||
)
|
||||
|
||||
# We should have made exactly two token-count calls: one for the
|
||||
# summarized context, one for the hard-evicted minimal context.
|
||||
assert mock_count_tokens.call_count == 2
|
||||
|
||||
# After hard eviction, only the system message should remain in-context.
|
||||
print("COMPACTED RESULT ======")
|
||||
for msg in result:
|
||||
print(f"MSG: {msg}")
|
||||
|
||||
# After hard eviction, we keep only:
|
||||
# 1. The system prompt
|
||||
# 2. The synthetic user summary message.
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1, f"Expected only the system message after hard eviction, got {len(result)} messages"
|
||||
assert len(result) == 2, f"Expected system + summary after hard eviction, got {len(result)} messages"
|
||||
assert result[0].role == MessageRole.system
|
||||
|
||||
# Agent state should also reflect exactly one message id.
|
||||
assert len(agent_loop.agent_state.message_ids) == 1
|
||||
|
||||
# context_token_estimate should be updated to the minimal token count
|
||||
# (second side-effect value from count_tokens), rather than the original
|
||||
# huge value.
|
||||
assert agent_loop.context_token_estimate == 10
|
||||
|
||||
# Verify that we logged an error about summarization failing to reduce
|
||||
# context size.
|
||||
error_logs = [rec for rec in caplog.records if "Summarization failed to sufficiently reduce context size" in rec.getMessage()]
|
||||
assert error_logs, "Expected an error log when summarization fails to reduce context size sufficiently"
|
||||
assert result[1].role == MessageRole.user
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -893,7 +892,6 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
|
||||
llm_config=llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Verify the summary was generated (actual LLM response)
|
||||
@@ -924,6 +922,105 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
async def test_large_system_prompt_summarization(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
"""
|
||||
Test edge case of large system prompt / memory blocks.
|
||||
|
||||
This test verifies that summarization handles the case where the system prompt
|
||||
and memory blocks are very large, potentially consuming most of the context window.
|
||||
The summarizer should gracefully handle this scenario without errors.
|
||||
"""
|
||||
|
||||
# Override context window to be small so we trigger summarization
|
||||
llm_config.context_window = 10000
|
||||
|
||||
# Create agent with large system prompt and memory blocks
|
||||
agent_name = f"test_agent_large_system_prompt_{llm_config.model}".replace(".", "_").replace("/", "_")
|
||||
agent_create = CreateAgent(
|
||||
name=agent_name,
|
||||
llm_config=llm_config,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
system="SYSTEM PROMPT " * 10000, # Large system prompt
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="human",
|
||||
limit=200000,
|
||||
value="NAME " * 10000, # Large memory block
|
||||
)
|
||||
],
|
||||
)
|
||||
agent_state = await server.agent_manager.create_agent_async(agent_create, actor=actor)
|
||||
|
||||
# Create a run for the agent using RunManager
|
||||
run = PydanticRun(agent_id=agent_state.id)
|
||||
run = await RunManager().create_run(pydantic_run=run, actor=actor)
|
||||
|
||||
# Create the agent loop using LettaAgentV3
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
# message the agent
|
||||
input_message = MessageCreate(role=MessageRole.user, content="Hello")
|
||||
|
||||
# Call step on the agent - may trigger summarization due to large context
|
||||
from letta.errors import SystemPromptTokenExceededError
|
||||
|
||||
with pytest.raises(SystemPromptTokenExceededError):
|
||||
response = await agent_loop.step(
|
||||
input_messages=[input_message],
|
||||
run_id=run.id,
|
||||
max_steps=3,
|
||||
)
|
||||
|
||||
# Repair the agent by shortening the memory blocks and system prompt
|
||||
# Update system prompt to a shorter version
|
||||
short_system_prompt = "You are a helpful assistant."
|
||||
await server.agent_manager.update_agent_async(
|
||||
agent_id=agent_state.id,
|
||||
agent_update=UpdateAgent(system=short_system_prompt),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Update memory block to a shorter version
|
||||
short_memory_value = "The user's name is Alice."
|
||||
await server.agent_manager.modify_block_by_label_async(
|
||||
agent_id=agent_state.id,
|
||||
block_label="human",
|
||||
block_update=BlockUpdate(value=short_memory_value),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Reload agent state after repairs
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
|
||||
print("REPAIRED AGENT STATE ======")
|
||||
print(agent_state.system)
|
||||
print(agent_state.blocks)
|
||||
|
||||
# Create a new run for the repaired agent
|
||||
run = PydanticRun(agent_id=agent_state.id)
|
||||
run = await RunManager().create_run(pydantic_run=run, actor=actor)
|
||||
|
||||
# Create a new agent loop with the repaired agent state
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
# Now the agent should be able to respond without context window errors
|
||||
response = await agent_loop.step(
|
||||
input_messages=[input_message],
|
||||
run_id=run.id,
|
||||
max_steps=3,
|
||||
)
|
||||
|
||||
# Verify we got a valid response after repair
|
||||
assert response is not None
|
||||
assert response.messages is not None
|
||||
print(f"Agent successfully responded after repair with {len(response.messages)} messages")
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
|
||||
# """
|
||||
@@ -1342,11 +1439,10 @@ async def test_summarize_all(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
llm_config=llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Verify the summary was generated
|
||||
assert len(new_in_context_messages) == 0
|
||||
assert len(new_in_context_messages) == 1
|
||||
assert summary is not None
|
||||
assert len(summary) > 0
|
||||
assert len(summary) <= 2000
|
||||
|
||||
Reference in New Issue
Block a user