diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index b138bd98..efbadea3 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -4,6 +4,7 @@ import traceback from typing import List, Tuple from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent @@ -77,7 +78,7 @@ class Summarizer: logger.info("Buffer length hit, evicting messages.") - target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1 + target_trim_index = len(all_in_context_messages) - self.message_buffer_min while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user: target_trim_index += 1 @@ -112,11 +113,12 @@ class Summarizer: summary_request_text = f"""You’re a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost. (Older) Evicted Messages:\n -{evicted_messages_str} +{evicted_messages_str}\n (Newer) In-Context Messages:\n {in_context_messages_str} """ + print(summary_request_text) # Fire-and-forget the summarization task self.fire_and_forget( self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])]) @@ -149,6 +151,9 @@ def format_transcript(messages: List[Message], include_system: bool = False) -> # 1) Try plain content if msg.content: + # Skip tool messages where the name is "send_message" + if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL: + continue text = "".join(c.text for c in msg.content).strip() # 2) Otherwise, try extracting from function calls @@ -156,11 +161,14 @@ def format_transcript(messages: List[Message], include_system: bool = False) -> parts = [] for call in msg.tool_calls: args_str = call.function.arguments - try: - args = json.loads(args_str) - # pull out a "message" field if present - parts.append(args.get("message", args_str)) - except json.JSONDecodeError: + if call.function.name == DEFAULT_MESSAGE_TOOL: + try: + args = json.loads(args_str) + # pull out a "message" field if present + parts.append(args.get(DEFAULT_MESSAGE_TOOL_KWARG, args_str)) + except json.JSONDecodeError: + parts.append(args_str) + else: parts.append(args_str) text = " ".join(parts).strip() diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 27109116..f4533395 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -1,16 +1,15 @@ import os import threading +from unittest.mock import MagicMock import pytest from dotenv import load_dotenv from letta_client import Letta from openai import AsyncOpenAI from openai.types.chat import ChatCompletionChunk -from sqlalchemy import delete from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent from letta.config import LettaConfig -from letta.orm import Provider, Step from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentType, CreateAgent from letta.schemas.block import CreateBlock @@ -20,7 +19,7 @@ from letta.schemas.group import ManagerType from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import MessageCreate +from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage from letta.schemas.tool import ToolCreate @@ -29,6 +28,8 @@ from letta.server.server import SyncServer from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager +from letta.services.summarizer.enums import SummarizationMode +from letta.services.summarizer.summarizer import Summarizer from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import get_persona_text @@ -48,16 +49,24 @@ MESSAGE_TRANSCRIPTS = [ "user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.", "assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.", "user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.", - "user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.", "assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.", "user: Yes, let’s do that.", "assistant: I’ll put together a day-by-day plan now.", ] -SUMMARY_REQ_TEXT = """ -Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity: +SYSTEM_MESSAGE = Message(role=MessageRole.system, content=[TextContent(text="System message")]) +MESSAGE_OBJECTS = [SYSTEM_MESSAGE] +for entry in MESSAGE_TRANSCRIPTS: + role_str, text = entry.split(":", 1) + role = MessageRole.user if role_str.strip() == "user" else MessageRole.assistant + MESSAGE_OBJECTS.append(Message(role=role, content=[TextContent(text=text.strip())])) +MESSAGE_EVICT_BREAKPOINT = 14 + +SUMMARY_REQ_TEXT = """ +You’re a memory-recall helper for an AI that can only keep the last 4 messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost. + +(Older) Evicted Messages: -(Older) 0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month. 1. assistant: That sounds amazing! Do you have any particular cities or sights in mind? 2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops. @@ -70,16 +79,13 @@ Here is the conversation history. Lines marked (Older) are about to be evicted; 9. assistant: Happy early birthday! Would you like gift ideas or celebration tips? 10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat. 11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated. + +(Newer) In-Context Messages: + 12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro. - -(Newer) -13. user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total. -14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops. -15. user: Yes, let’s do that. -16. assistant: I’ll put together a day-by-day plan now. - -Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk’s `start_index`, `end_index`, and a one-sentence `contextual_description`. - """ +13. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops. +14. user: Yes, let’s do that. +15. assistant: I’ll put together a day-by-day plan now.""" # --- Server Management --- # @@ -214,22 +220,12 @@ def org_id(server): yield org.id - # cleanup - with server.organization_manager.session_maker() as session: - session.execute(delete(Step)) - session.execute(delete(Provider)) - session.commit() - server.organization_manager.delete_organization_by_id(org.id) - @pytest.fixture(scope="module") def actor(server, org_id): user = server.user_manager.create_default_user() yield user - # cleanup - server.user_manager.delete_user_by_id(user.id) - # --- Helper Functions --- # @@ -301,6 +297,80 @@ async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpo print(chunk.choices[0].delta.content) +@pytest.mark.asyncio +async def test_summarization(disable_e2b_api_key, voice_agent): + agent_manager = AgentManager() + user_manager = UserManager() + actor = user_manager.get_default_user() + + request = CreateAgent( + name=voice_agent.name + "-sleeptime", + agent_type=AgentType.voice_sleeptime_agent, + block_ids=[block.id for block in voice_agent.memory.blocks], + memory_blocks=[ + CreateBlock( + label="memory_persona", + value=get_persona_text("voice_memory_persona"), + ), + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + project_id=voice_agent.project_id, + ) + sleeptime_agent = agent_manager.create_agent(request, actor=actor) + + async_client = AsyncOpenAI() + + memory_agent = VoiceSleeptimeAgent( + agent_id=sleeptime_agent.id, + convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent + openai_client=async_client, + message_manager=MessageManager(), + agent_manager=agent_manager, + actor=actor, + block_manager=BlockManager(), + target_block_label="human", + message_transcripts=MESSAGE_TRANSCRIPTS, + ) + + summarizer = Summarizer( + mode=SummarizationMode.STATIC_MESSAGE_BUFFER, + summarizer_agent=memory_agent, + message_buffer_limit=8, + message_buffer_min=4, + ) + + # stub out the agent.step so it returns a known sentinel + memory_agent.step = MagicMock(return_value="STEP_RESULT") + + # patch fire_and_forget on *this* summarizer instance to a MagicMock + summarizer.fire_and_forget = MagicMock() + + # now call the method under test + in_ctx = MESSAGE_OBJECTS[:MESSAGE_EVICT_BREAKPOINT] + new_msgs = MESSAGE_OBJECTS[MESSAGE_EVICT_BREAKPOINT:] + # call under test (this is sync) + updated, did_summarize = summarizer._static_buffer_summarization( + in_context_messages=in_ctx, + new_letta_messages=new_msgs, + ) + + assert did_summarize is True + assert len(updated) == summarizer.message_buffer_min + 1 # One extra for system message + assert updated[0].role == MessageRole.system # Preserved system message + + # 2) the summarizer_agent.step() should have been *called* exactly once + memory_agent.step.assert_called_once() + call_args = memory_agent.step.call_args.args[0] # the single positional argument: a list of MessageCreate + assert isinstance(call_args, list) + assert isinstance(call_args[0], MessageCreate) + assert call_args[0].role == MessageRole.user + assert "15. assistant: I’ll put together a day-by-day plan now." in call_args[0].content[0].text + + # 3) fire_and_forget should have been called once, and its argument must be the coroutine returned by step() + summarizer.fire_and_forget.assert_called_once() + + @pytest.mark.asyncio async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): """Tests chat completion streaming using the Async OpenAI client."""