feat: Rework summarizer (#654)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
@@ -8,9 +9,13 @@ import pytest
|
||||
from letta import create_client
|
||||
from letta.agent import Agent
|
||||
from letta.client.client import LocalClient
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
|
||||
from tests.helpers.utils import cleanup
|
||||
@@ -44,6 +49,101 @@ def agent_state(client):
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
# Sample data setup
|
||||
def generate_message(role: str, text: str = None, tool_calls: List = None) -> Message:
    """Build a throwaway Message object for use in tests.

    Args:
        role: Role name; converted with ``MessageRole(role)``.
        text: Message body. When falsy, ``"<role> message text"`` is used instead.
        tool_calls: Tool calls to attach. When falsy, an empty list is used instead.

    Returns:
        A Message with a fresh ``message-<uuid4>`` id and ``created_at`` taken
        from ``datetime.utcnow()``.
    """
    message_id = f"message-{uuid.uuid4()}"
    body = text if text else f"{role} message text"
    calls = tool_calls if tool_calls else []
    return Message(
        id=message_id,
        role=MessageRole(role),
        text=body,
        created_at=datetime.utcnow(),
        tool_calls=calls,
    )
|
||||
|
||||
|
||||
def test_cutoff_calculation(mocker):
    """Check the cutoff index returned by calculate_summarizer_cutoff under several summarizer settings."""
    # Arrange: a mock logger and a five-message conversation (system, user, assistant, user, assistant).
    mock_logger = mocker.Mock()
    roles = ("system", "user", "assistant", "user", "assistant")
    conversation = [generate_message(role) for role in roles]

    def check_cutoff(counts, expected_cutoff, expected_role):
        """Compute the cutoff for `counts`, then verify its value and the role of the message just before it."""
        index = calculate_summarizer_cutoff(conversation, counts, mock_logger)
        assert index == expected_cutoff
        assert conversation[index - 1].role == expected_role

    mocker.patch("letta.settings.summarizer_settings.desired_memory_token_pressure", 0.5)
    mocker.patch("letta.settings.summarizer_settings.evict_all_messages", False)

    # Basic scenarios
    check_cutoff([4, 2, 8, 2, 2], 3, MessageRole.assistant)
    check_cutoff([4, 2, 2, 2, 2], 5, MessageRole.assistant)
    check_cutoff([2, 2, 3, 2, 2], 3, MessageRole.assistant)

    # Evict all messages: the cutoff is the end of token_counts,
    # even though evicting only up to the 100-token message would be enough.
    mocker.patch("letta.settings.summarizer_settings.evict_all_messages", True)
    heavy_middle = [1, 1, 100, 1, 1]
    check_cutoff(heavy_middle, 5, MessageRole.assistant)

    # Same token_counts without evicting everything: the cutoff now lands just past the 100-token message.
    mocker.patch("letta.settings.summarizer_settings.evict_all_messages", False)
    check_cutoff(heavy_middle, 3, MessageRole.assistant)

    # With `keep_last_n_messages` set to 3, the expected cutoff moves to 2 (a user message precedes it).
    mocker.patch("letta.settings.summarizer_settings.keep_last_n_messages", 3)
    check_cutoff([4, 2, 2, 2, 2], 2, MessageRole.user)
|
||||
|
||||
|
||||
def test_summarize_many_messages_basic(client, mock_e2b_api_key_none):
    """Send ten user messages to an agent with a 3000-token context window; none is expected to raise.

    The agent is created from the default ``gpt-4o-mini`` LLM config with its
    ``context_window`` overridden, and is deleted again when the test ends.
    """
    small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
    small_context_llm_config.context_window = 3000
    small_agent_state = client.create_agent(
        name="small_context_agent",
        llm_config=small_context_llm_config,
    )
    try:
        for _ in range(10):
            client.user_message(
                agent_id=small_agent_state.id,
                message="hi " * 60,
            )
    finally:
        # Delete the agent even if a message raises, so a failing run does not leave it behind.
        client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_large_message_does_not_loop_infinitely(client, mock_e2b_api_key_none):
    """Send one very large message to an agent with a 2000-token context window.

    The call is expected to raise ContextWindowExceededError with a message
    matching ``"Ran summarizer <summarizer_settings.max_summarizer_retries>"``.
    The agent is deleted again when the test ends.
    """
    small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
    small_context_llm_config.context_window = 2000
    small_agent_state = client.create_agent(
        name="super_small_context_agent",
        llm_config=small_context_llm_config,
    )
    try:
        with pytest.raises(ContextWindowExceededError, match=f"Ran summarizer {summarizer_settings.max_summarizer_retries}"):
            client.user_message(
                agent_id=small_agent_state.id,
                message="hi " * 1000,
            )
    finally:
        # Delete the agent even if the expected error is not raised, so a failing run does not leave it behind.
        client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none):
|
||||
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
|
||||
# First send a few messages (5)
|
||||
@@ -134,7 +234,7 @@ def test_auto_summarize(client, mock_e2b_api_key_none):
|
||||
# "gemini-pro.json", TODO: Gemini is broken
|
||||
],
|
||||
)
|
||||
def test_summarizer(config_filename):
|
||||
def test_summarizer(config_filename, client, agent_state):
|
||||
namespace = uuid.NAMESPACE_DNS
|
||||
agent_name = str(uuid.uuid5(namespace, f"integration-test-summarizer-{config_filename}"))
|
||||
|
||||
@@ -175,6 +275,6 @@ def test_summarizer(config_filename):
|
||||
)
|
||||
|
||||
# Invoke a summarize
|
||||
letta_agent.summarize_messages_inplace(preserve_last_N_messages=False)
|
||||
letta_agent.summarize_messages_inplace()
|
||||
in_context_messages = client.get_in_context_messages(agent_state.id)
|
||||
assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}"
|
||||
|
||||
Reference in New Issue
Block a user