test: Add tests for summarizer enumeration (#1952)

This commit is contained in:
Matthew Zhou
2025-04-30 16:26:49 -07:00
committed by GitHub
parent 0b060b88aa
commit 1bf82a1e7c
2 changed files with 111 additions and 33 deletions

View File

@@ -4,6 +4,7 @@ import traceback
from typing import List, Tuple
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
@@ -77,7 +78,7 @@ class Summarizer:
logger.info("Buffer length hit, evicting messages.")
target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1
target_trim_index = len(all_in_context_messages) - self.message_buffer_min
while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user:
target_trim_index += 1
@@ -112,11 +113,12 @@ class Summarizer:
summary_request_text = f"""Youre a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they arent lost.
(Older) Evicted Messages:\n
{evicted_messages_str}
{evicted_messages_str}\n
(Newer) In-Context Messages:\n
{in_context_messages_str}
"""
print(summary_request_text)
# Fire-and-forget the summarization task
self.fire_and_forget(
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
@@ -149,6 +151,9 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
# 1) Try plain content
if msg.content:
# Skip tool messages where the name is "send_message"
if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL:
continue
text = "".join(c.text for c in msg.content).strip()
# 2) Otherwise, try extracting from function calls
@@ -156,11 +161,14 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
parts = []
for call in msg.tool_calls:
args_str = call.function.arguments
try:
args = json.loads(args_str)
# pull out a "message" field if present
parts.append(args.get("message", args_str))
except json.JSONDecodeError:
if call.function.name == DEFAULT_MESSAGE_TOOL:
try:
args = json.loads(args_str)
# pull out a "message" field if present
parts.append(args.get(DEFAULT_MESSAGE_TOOL_KWARG, args_str))
except json.JSONDecodeError:
parts.append(args_str)
else:
parts.append(args_str)
text = " ".join(parts).strip()

View File

@@ -1,16 +1,15 @@
import os
import threading
from unittest.mock import MagicMock
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
from sqlalchemy import delete
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentType, CreateAgent
from letta.schemas.block import CreateBlock
@@ -20,7 +19,7 @@ from letta.schemas.group import ManagerType
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
from letta.schemas.tool import ToolCreate
@@ -29,6 +28,8 @@ from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import get_persona_text
@@ -48,16 +49,24 @@ MESSAGE_TRANSCRIPTS = [
"user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.",
"assistant: How about Vegan Treats in Santa Barbara? Theyre highly rated.",
"user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.",
"user: I want to make sure my itinerary isnt too tight—aiming for 34 days total.",
"assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.",
"user: Yes, lets do that.",
"assistant: Ill put together a day-by-day plan now.",
]
SUMMARY_REQ_TEXT = """
Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity:
# Message objects mirroring MESSAGE_TRANSCRIPTS, preceded by a single system message.
SYSTEM_MESSAGE = Message(role=MessageRole.system, content=[TextContent(text="System message")])
MESSAGE_OBJECTS = [SYSTEM_MESSAGE]
for entry in MESSAGE_TRANSCRIPTS:
    # Each transcript entry has the form "<role>: <text>"; split on the first colon only.
    role_str, text = entry.split(":", 1)
    if role_str.strip() == "user":
        role = MessageRole.user
    else:
        # Anything that is not "user" is treated as an assistant turn.
        role = MessageRole.assistant
    content = [TextContent(text=text.strip())]
    MESSAGE_OBJECTS.append(Message(role=role, content=content))

# Index at which test_summarization slices MESSAGE_OBJECTS into in-context vs. new messages.
MESSAGE_EVICT_BREAKPOINT = 14
SUMMARY_REQ_TEXT = """
Youre a memory-recall helper for an AI that can only keep the last 4 messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they arent lost.
(Older) Evicted Messages:
(Older)
0. user: Hey, Ive been thinking about planning a road trip up the California coast next month.
1. assistant: That sounds amazing! Do you have any particular cities or sights in mind?
2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.
@@ -70,16 +79,13 @@ Here is the conversation history. Lines marked (Older) are about to be evicted;
9. assistant: Happy early birthday! Would you like gift ideas or celebration tips?
10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.
11. assistant: How about Vegan Treats in Santa Barbara? Theyre highly rated.
(Newer) In-Context Messages:
12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.
(Newer)
13. user: I want to make sure my itinerary isnt too tight—aiming for 34 days total.
14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
15. user: Yes, lets do that.
16. assistant: Ill put together a day-by-day plan now.
Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunks `start_index`, `end_index`, and a one-sentence `contextual_description`.
"""
13. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
14. user: Yes, lets do that.
15. assistant: Ill put together a day-by-day plan now."""
# --- Server Management --- #
@@ -214,22 +220,12 @@ def org_id(server):
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
server.organization_manager.delete_organization_by_id(org.id)
@pytest.fixture(scope="module")
def actor(server, org_id):
    """Yield the server's default user for the module, deleting it on teardown."""
    default_user = server.user_manager.create_default_user()
    yield default_user

    # teardown: remove the user created above
    server.user_manager.delete_user_by_id(default_user.id)
# --- Helper Functions --- #
@@ -301,6 +297,80 @@ async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpo
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio
async def test_summarization(disable_e2b_api_key, voice_agent):
    """Exercise Summarizer._static_buffer_summarization with a stubbed summarizer agent.

    A voice sleeptime agent is created alongside ``voice_agent`` and wrapped in a
    ``Summarizer`` (buffer limit 8, buffer min 4). Both the agent's ``step`` and the
    summarizer's ``fire_and_forget`` are replaced with mocks, so the test only checks
    the trimmed message list and the summary request passed to ``step``.
    """
    agent_manager = AgentManager()
    user_manager = UserManager()
    actor = user_manager.get_default_user()

    # Create the sleeptime agent, sharing the voice agent's memory blocks.
    shared_block_ids = [block.id for block in voice_agent.memory.blocks]
    persona_block = CreateBlock(
        label="memory_persona",
        value=get_persona_text("voice_memory_persona"),
    )
    create_request = CreateAgent(
        name=voice_agent.name + "-sleeptime",
        agent_type=AgentType.voice_sleeptime_agent,
        block_ids=shared_block_ids,
        memory_blocks=[persona_block],
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        project_id=voice_agent.project_id,
    )
    sleeptime_state = agent_manager.create_agent(create_request, actor=actor)

    # Collaborators are built in the same order the original constructor call evaluated them.
    openai_client = AsyncOpenAI()
    message_manager = MessageManager()
    block_manager = BlockManager()
    sleeptime_runner = VoiceSleeptimeAgent(
        agent_id=sleeptime_state.id,
        convo_agent_state=sleeptime_state,  # In reality, this will be the main convo agent
        openai_client=openai_client,
        message_manager=message_manager,
        agent_manager=agent_manager,
        actor=actor,
        block_manager=block_manager,
        target_block_label="human",
        message_transcripts=MESSAGE_TRANSCRIPTS,
    )

    summarizer = Summarizer(
        mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
        summarizer_agent=sleeptime_runner,
        message_buffer_limit=8,
        message_buffer_min=4,
    )

    # Replace the agent's step with a stub returning a sentinel, and swap
    # fire_and_forget on this summarizer instance for a mock.
    sleeptime_runner.step = MagicMock(return_value="STEP_RESULT")
    summarizer.fire_and_forget = MagicMock()

    # Split the fixture messages into those already in context and the newly arrived ones.
    resident_messages = MESSAGE_OBJECTS[:MESSAGE_EVICT_BREAKPOINT]
    incoming_messages = MESSAGE_OBJECTS[MESSAGE_EVICT_BREAKPOINT:]

    # Method under test (a plain synchronous call).
    trimmed, summarized = summarizer._static_buffer_summarization(
        in_context_messages=resident_messages,
        new_letta_messages=incoming_messages,
    )

    # 1) summarization happened and the buffer was trimmed, keeping the system message first
    assert summarized is True
    assert len(trimmed) == summarizer.message_buffer_min + 1  # One extra for system message
    assert trimmed[0].role == MessageRole.system  # Preserved system message

    # 2) step() was called exactly once with a list whose first item is a user MessageCreate
    sleeptime_runner.step.assert_called_once()
    step_payload = sleeptime_runner.step.call_args.args[0]  # first positional argument
    assert isinstance(step_payload, list)
    first_request = step_payload[0]
    assert isinstance(first_request, MessageCreate)
    assert first_request.role == MessageRole.user
    assert "15. assistant: Ill put together a day-by-day plan now." in first_request.content[0].text

    # 3) fire_and_forget was invoked exactly once
    summarizer.fire_and_forget.assert_called_once()
@pytest.mark.asyncio
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
"""Tests chat completion streaming using the Async OpenAI client."""