feat: Rework summarizer (#654)

This commit is contained in:
Matthew Zhou
2025-01-22 09:19:26 -10:00
committed by GitHub
parent c589e7cafb
commit 50de3cb4b7
11 changed files with 289 additions and 165 deletions

View File

@@ -1,6 +1,7 @@
import json
import os
import uuid
from datetime import datetime
from typing import List
import pytest
@@ -8,9 +9,13 @@ import pytest
from letta import create_client
from letta.agent import Agent
from letta.client.client import LocalClient
from letta.errors import ContextWindowExceededError
from letta.llm_api.helpers import calculate_summarizer_cutoff
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.settings import summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
from tests.helpers.utils import cleanup
@@ -44,6 +49,101 @@ def agent_state(client):
client.delete_agent(agent_state.id)
# Sample data setup
def generate_message(role: str, text: str = None, tool_calls: List = None) -> Message:
"""Helper to generate a Message object."""
return Message(
id="message-" + str(uuid.uuid4()),
role=MessageRole(role),
text=text or f"{role} message text",
created_at=datetime.utcnow(),
tool_calls=tool_calls or [],
)
def test_cutoff_calculation(mocker):
"""Test basic scenarios where the function calculates the cutoff correctly."""
# Arrange
logger = mocker.Mock() # Mock logger
messages = [
generate_message("system"),
generate_message("user"),
generate_message("assistant"),
generate_message("user"),
generate_message("assistant"),
]
mocker.patch("letta.settings.summarizer_settings.desired_memory_token_pressure", 0.5)
mocker.patch("letta.settings.summarizer_settings.evict_all_messages", False)
# Basic tests
token_counts = [4, 2, 8, 2, 2]
cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
assert cutoff == 3
assert messages[cutoff - 1].role == MessageRole.assistant
token_counts = [4, 2, 2, 2, 2]
cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
assert cutoff == 5
assert messages[cutoff - 1].role == MessageRole.assistant
token_counts = [2, 2, 3, 2, 2]
cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
assert cutoff == 3
assert messages[cutoff - 1].role == MessageRole.assistant
# Evict all messages
# Should give the end of the token_counts, even though it is not necessary (can just evict up to the 100)
mocker.patch("letta.settings.summarizer_settings.evict_all_messages", True)
token_counts = [1, 1, 100, 1, 1]
cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
assert cutoff == 5
assert messages[cutoff - 1].role == MessageRole.assistant
# Don't evict all messages with same token_counts, cutoff now should be at the 100
# Should give the end of the token_counts, even though it is not necessary (can just evict up to the 100)
mocker.patch("letta.settings.summarizer_settings.evict_all_messages", False)
cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
assert cutoff == 3
assert messages[cutoff - 1].role == MessageRole.assistant
# Set `keep_last_n_messages`
mocker.patch("letta.settings.summarizer_settings.keep_last_n_messages", 3)
token_counts = [4, 2, 2, 2, 2]
cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
assert cutoff == 2
assert messages[cutoff - 1].role == MessageRole.user
def test_summarize_many_messages_basic(client, mock_e2b_api_key_none):
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
small_context_llm_config.context_window = 3000
small_agent_state = client.create_agent(
name="small_context_agent",
llm_config=small_context_llm_config,
)
for _ in range(10):
client.user_message(
agent_id=small_agent_state.id,
message="hi " * 60,
)
client.delete_agent(small_agent_state.id)
def test_summarize_large_message_does_not_loop_infinitely(client, mock_e2b_api_key_none):
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
small_context_llm_config.context_window = 2000
small_agent_state = client.create_agent(
name="super_small_context_agent",
llm_config=small_context_llm_config,
)
with pytest.raises(ContextWindowExceededError, match=f"Ran summarizer {summarizer_settings.max_summarizer_retries}"):
client.user_message(
agent_id=small_agent_state.id,
message="hi " * 1000,
)
client.delete_agent(small_agent_state.id)
def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none):
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
# First send a few messages (5)
@@ -134,7 +234,7 @@ def test_auto_summarize(client, mock_e2b_api_key_none):
# "gemini-pro.json", TODO: Gemini is broken
],
)
def test_summarizer(config_filename):
def test_summarizer(config_filename, client, agent_state):
namespace = uuid.NAMESPACE_DNS
agent_name = str(uuid.uuid5(namespace, f"integration-test-summarizer-{config_filename}"))
@@ -175,6 +275,6 @@ def test_summarizer(config_filename):
)
# Invoke a summarize
letta_agent.summarize_messages_inplace(preserve_last_N_messages=False)
letta_agent.summarize_messages_inplace()
in_context_messages = client.get_in_context_messages(agent_state.id)
assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}"