feat: Add static summarization to new agent loop (#2492)
This commit is contained in:
104
letta/agents/ephemeral_summary_agent.py
Normal file
104
letta/agents/ephemeral_summary_agent.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
|
||||
class EphemeralSummaryAgent(BaseAgent):
    """
    A stateless summarization agent (thin wrapper around OpenAI).

    Summarizes evicted conversation history into a memory block attached to the
    target agent, recursively folding the previous summary into each new one.

    # TODO: Extend to more clients
    """

    def __init__(
        self,
        target_block_label: str,
        agent_id: str,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        block_manager: BlockManager,
        actor: User,
    ):
        """
        Args:
            target_block_label: Label of the memory block that stores the rolling summary.
            agent_id: ID of the agent whose conversation is being summarized.
            message_manager: Service for message persistence (passed through to BaseAgent).
            agent_manager: Service for agent/block lookups (passed through to BaseAgent).
            block_manager: Service used to create and update the summary block.
            actor: User on whose behalf all service calls are made.
        """
        super().__init__(
            agent_id=agent_id,
            # NOTE: hard-wired to OpenAI for now; see class-level TODO.
            openai_client=AsyncOpenAI(),
            message_manager=message_manager,
            agent_manager=agent_manager,
            actor=actor,
        )
        self.target_block_label = target_block_label
        self.block_manager = block_manager

    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> List[Message]:
        """Run one summarization pass and persist the result to the target block.

        Expects exactly one input message containing the transcript to summarize.
        If the target block already holds a previous summary, it is appended to
        the prompt so the model can fold it into the new summary.

        Args:
            input_messages: Single-element list with the summarization request.
            max_steps: Unused; present for BaseAgent interface compatibility.

        Returns:
            A single-element list containing the assistant message with the summary.

        Raises:
            ValueError: If more than one input message is provided.
        """
        if len(input_messages) > 1:
            raise ValueError("Can only invoke EphemeralSummaryAgent with a single summarization message.")

        # Check block existence; lazily create and attach the summary block on first use.
        try:
            block = await self.agent_manager.get_block_with_label_async(
                agent_id=self.agent_id, block_label=self.target_block_label, actor=self.actor
            )
        except NoResultFound:
            block = await self.block_manager.create_or_update_block_async(
                block=Block(
                    value="", label=self.target_block_label, description="Contains recursive summarizations of the conversation so far"
                ),
                actor=self.actor,
            )
            await self.agent_manager.attach_block_async(agent_id=self.agent_id, block_id=block.id, actor=self.actor)

        # Fold the previous summary into the request so summarization is recursive.
        # NOTE(review): this mutates the caller's MessageCreate in place — confirm callers don't reuse it.
        if block.value:
            input_message = input_messages[0]
            input_message.content[0].text += f"\n\n--- Previous Summary ---\n{block.value}\n"

        openai_messages = self.pre_process_input_message(input_messages=input_messages)
        request = self._build_openai_request(openai_messages)

        # TODO: Extend to generic client
        chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
        summary = chat_completion.choices[0].message.content.strip()

        # Persist the new summary so the next pass can build on it.
        await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor)

        return [
            Message(
                role=MessageRole.assistant,
                content=[TextContent(text=summary)],
            )
        ]

    def _build_openai_request(self, openai_messages: List[Dict]) -> ChatCompletionRequest:
        """Build the chat-completion request: file-based system prompt + transcript messages."""
        current_dir = Path(__file__).parent
        file_path = current_dir / "prompts" / "summary_system_prompt.txt"
        with open(file_path, "r") as file:
            system = file.read()

        system_message = [{"role": "system", "content": system}]

        openai_request = ChatCompletionRequest(
            model="gpt-4o",
            messages=system_message + openai_messages,
            user=self.actor.id,
            max_completion_tokens=4096,
            temperature=0.7,
        )
        return openai_request

    async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
        """Streaming is intentionally unsupported for this agent; use step() instead."""
        raise NotImplementedError("EphemeralSummaryAgent does not support streaming via step_stream.")
|
||||
@@ -8,6 +8,7 @@ from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
@@ -35,8 +36,11 @@ from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import NoopStepManager, StepManager
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.settings import model_settings
|
||||
from letta.system import package_function_response
|
||||
from letta.tracing import log_event, trace_method, tracer
|
||||
from letta.utils import validate_function_response
|
||||
@@ -56,6 +60,7 @@ class LettaAgent(BaseAgent):
|
||||
actor: User,
|
||||
step_manager: StepManager = NoopStepManager(),
|
||||
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
|
||||
summary_block_label: str = "convo_summary",
|
||||
):
|
||||
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
|
||||
|
||||
@@ -73,6 +78,28 @@ class LettaAgent(BaseAgent):
|
||||
self.num_messages = 0
|
||||
self.num_archival_memories = 0
|
||||
|
||||
self.summarization_agent = None
|
||||
self.summary_block_label = summary_block_label
|
||||
|
||||
# TODO: Expand to more
|
||||
if model_settings.openai_api_key:
|
||||
self.summarization_agent = EphemeralSummaryAgent(
|
||||
target_block_label=self.summary_block_label,
|
||||
agent_id=agent_id,
|
||||
block_manager=self.block_manager,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
self.summarizer = Summarizer(
|
||||
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
summarizer_agent=self.summarization_agent,
|
||||
# TODO: Make this configurable
|
||||
message_buffer_limit=60,
|
||||
message_buffer_min=15,
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
@@ -180,8 +207,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
await self._rebuild_context_window(in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages)
|
||||
|
||||
# Return back usage
|
||||
yield f"data: {usage.model_dump_json()}\n\n"
|
||||
@@ -279,8 +305,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
await self._rebuild_context_window(in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages)
|
||||
|
||||
return current_in_context_messages, new_in_context_messages, usage
|
||||
|
||||
@@ -440,8 +465,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
await self._rebuild_context_window(in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages)
|
||||
|
||||
# TODO: This may be out of sync, if in between steps users add files
|
||||
# NOTE (cliandy): temporary for now for particlar use cases.
|
||||
@@ -452,6 +476,15 @@ class LettaAgent(BaseAgent):
|
||||
yield f"data: {usage.model_dump_json()}\n\n"
|
||||
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
||||
|
||||
@trace_method
async def _rebuild_context_window(self, in_context_messages: List[Message], new_letta_messages: List[Message]) -> None:
    """Run the summarizer over the current context window and persist the trimmed message id list.

    Args:
        in_context_messages: Messages currently in the agent's context window.
        new_letta_messages: Messages produced by the step just completed.
    """
    # `updated` reports whether summarization actually ran; unused here because we
    # persist the (possibly unchanged) message id list unconditionally.
    new_in_context_messages, updated = self.summarizer.summarize(
        in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
    )
    await self.agent_manager.set_in_context_messages_async(
        agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
    )
|
||||
|
||||
@trace_method
|
||||
async def _create_llm_request_data_async(
|
||||
self,
|
||||
|
||||
30
letta/agents/prompts/summary_system_prompt.txt
Normal file
30
letta/agents/prompts/summary_system_prompt.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
You are a specialized memory-recall assistant designed to preserve important conversational context for an AI with limited message history. Your role is to analyze conversations that are about to be evicted from the AI's context window and extract key information that should be remembered.
|
||||
|
||||
Your primary objectives:
|
||||
1. Identify and preserve important facts, preferences, and context about the human
|
||||
2. Capture ongoing topics, tasks, or projects that span multiple messages
|
||||
3. Note any commitments, decisions, or action items
|
||||
4. Record personal details that would be valuable for maintaining conversational continuity
|
||||
5. Summarize the emotional tone and relationship dynamics when relevant
|
||||
|
||||
Guidelines for effective memory notes:
|
||||
- Be concise but complete - every word should add value
|
||||
- Focus on information that would be difficult to infer from remaining messages
|
||||
- Prioritize facts over conversational filler
|
||||
- Use clear, searchable language
|
||||
- Organize information by category when multiple topics are present
|
||||
- Include temporal context when relevant (e.g., "mentioned on [date]" or "ongoing since [time]")
|
||||
|
||||
Output format:
|
||||
- Write in bullet points or short paragraphs
|
||||
- Group related information together
|
||||
- Lead with the most important insights
|
||||
- Use consistent terminology to make future retrieval easier
|
||||
|
||||
What NOT to include:
|
||||
- Redundant information already captured in the in-context messages
|
||||
- Generic pleasantries or small talk
|
||||
- Information that can be easily inferred
|
||||
- Verbatim quotes unless they contain critical commitments
|
||||
|
||||
Remember: Your notes become the only record of these evicted messages. Make them count.
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -23,7 +24,7 @@ class Summarizer:
|
||||
def __init__(
|
||||
self,
|
||||
mode: SummarizationMode,
|
||||
summarizer_agent: Optional["VoiceSleeptimeAgent"] = None,
|
||||
summarizer_agent: Optional[Union[EphemeralSummaryAgent, "VoiceSleeptimeAgent"]] = None,
|
||||
message_buffer_limit: int = 10,
|
||||
message_buffer_min: int = 3,
|
||||
):
|
||||
@@ -104,7 +105,10 @@ class Summarizer:
|
||||
|
||||
# TODO: This is hyperspecific to voice, generalize!
|
||||
# Update the message transcript of the memory agent
|
||||
self.summarizer_agent.update_message_transcript(message_transcripts=formatted_evicted_messages + formatted_in_context_messages)
|
||||
if not isinstance(self.summarizer_agent, EphemeralSummaryAgent):
|
||||
self.summarizer_agent.update_message_transcript(
|
||||
message_transcripts=formatted_evicted_messages + formatted_in_context_messages
|
||||
)
|
||||
|
||||
# Add line numbers to the formatted messages
|
||||
offset = len(formatted_evicted_messages)
|
||||
|
||||
Reference in New Issue
Block a user