feat: Move sleeptime voice agent to new agent loop (#1979)

This commit is contained in:
Matthew Zhou
2025-05-01 20:48:33 -07:00
committed by GitHub
parent 9722596a74
commit daa30d6662
10 changed files with 419 additions and 420 deletions

View File

@@ -1,3 +1,4 @@
import xml.etree.ElementTree as ET
from typing import List, Tuple
from letta.schemas.agent import AgentState
@@ -50,3 +51,56 @@ def _prepare_in_context_messages(
)
return current_in_context_messages, new_in_context_messages
def serialize_message_history(messages: List[str], context: str) -> str:
    """Serialize a message transcript and its context into an XML string.

    The resulting document has the shape:
        <memory>
            <messages>
                <message>…</message>
                <message>…</message>
            </messages>
            <context>…</context>
        </memory>
    """
    memory_el = ET.Element("memory")
    messages_el = ET.SubElement(memory_el, "messages")
    for text in messages:
        ET.SubElement(messages_el, "message").text = text
    context_el = ET.SubElement(memory_el, "context")
    context_el.text = context
    # ET.tostring escapes XML-reserved characters (e.g. < and &) for us.
    return ET.tostring(memory_el, encoding="unicode")
def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
    """Parse XML produced by ``serialize_message_history`` back into its parts.

    Args:
        xml_str: An XML document of the form
            ``<memory><messages><message>…</message>…</messages><context>…</context></memory>``.

    Returns:
        A tuple ``(messages, context)``.

    Raises:
        ValueError: If the XML is malformed or a required section is missing.
    """
    try:
        root = ET.fromstring(xml_str)
    except ET.ParseError as e:
        # Chain the original parse error so callers can inspect the root cause.
        raise ValueError(f"Invalid XML: {e}") from e
    msgs_el = root.find("messages")
    if msgs_el is None:
        raise ValueError("Missing <messages> section")
    # .text may be None for an empty element, so coerce to empty string
    messages = [m.text or "" for m in msgs_el.findall("message")]
    sum_el = root.find("context")
    if sum_el is None:
        raise ValueError("Missing <context> section")
    context = sum_el.text or ""
    return messages, context

View File

@@ -62,6 +62,14 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = await self._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message]]:
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_messages, agent_state, self.message_manager, self.actor
)
@@ -72,7 +80,7 @@ class LettaAgent(BaseAgent):
put_inner_thoughts_first=True,
actor_id=self.actor.id,
)
for step in range(max_steps):
for _ in range(max_steps):
response = await self._get_ai_reply(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
@@ -83,6 +91,7 @@ class LettaAgent(BaseAgent):
)
tool_call = response.choices[0].message.tool_calls[0]
persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
@@ -95,7 +104,7 @@ class LettaAgent(BaseAgent):
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
return current_in_context_messages, new_in_context_messages
@trace_method
async def step_stream(
@@ -117,7 +126,7 @@ class LettaAgent(BaseAgent):
actor_id=self.actor.id,
)
for step in range(max_steps):
for _ in range(max_steps):
stream = await self._get_ai_reply(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
@@ -181,6 +190,7 @@ class LettaAgent(BaseAgent):
ToolType.LETTA_MEMORY_CORE,
ToolType.LETTA_MULTI_AGENT_CORE,
ToolType.LETTA_SLEEPTIME_CORE,
ToolType.LETTA_VOICE_SLEEPTIME_CORE,
}
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)

View File

@@ -97,13 +97,12 @@ class VoiceAgent(BaseAgent):
summarizer_agent=VoiceSleeptimeAgent(
agent_id=voice_sleeptime_agent_id,
convo_agent_state=agent_state,
openai_client=self.openai_client,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
actor=self.actor,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
target_block_label=self.summary_block_label,
message_transcripts=[],
),
message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,

View File

@@ -1,332 +1,138 @@
import json
import xml.etree.ElementTree as ET
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import AsyncGenerator, List, Tuple, Union
import openai
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, serialize_message_history
from letta.agents.letta_agent import LettaAgent
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.block import BlockUpdate
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, UserMessage
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from letta.schemas.user import User
from letta.server.rest_api.utils import convert_in_context_letta_messages_to_openai, create_input_messages
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.system import package_function_response
from letta.services.passage_manager import PassageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.tracing import trace_method
# TODO: Move this to the new Letta Agent loop
class VoiceSleeptimeAgent(BaseAgent):
class VoiceSleeptimeAgent(LettaAgent):
"""
A stateless agent that helps with offline memory computations.
A special variant of the LettaAgent that helps with offline memory computations specifically for voice.
"""
def __init__(
self,
agent_id: str,
convo_agent_state: AgentState,
openai_client: openai.AsyncClient,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
target_block_label: str,
message_transcripts: List[str],
actor: User,
):
super().__init__(
agent_id=agent_id,
openai_client=openai_client,
message_manager=message_manager,
agent_manager=agent_manager,
block_manager=block_manager,
passage_manager=passage_manager,
actor=actor,
)
self.convo_agent_state = convo_agent_state
self.block_manager = block_manager
self.target_block_label = target_block_label
self.message_transcripts = message_transcripts
self.message_transcripts = []
self.summarizer = Summarizer(
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
summarizer_agent=None,
message_buffer_limit=20,
message_buffer_min=10,
)
def update_message_transcript(self, message_transcripts: List[str]):
self.message_transcripts = message_transcripts
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = 20) -> LettaResponse:
"""
Process the user's input message, allowing the model to call memory-related tools
until it decides to stop and provide a final response.
"""
agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
in_context_messages = create_input_messages(input_messages=input_messages, agent_id=self.agent_id, actor=self.actor)
openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
# 1. Store memories
request = self._build_openai_request(openai_messages, agent_state, tools=self._build_store_memory_tool_schemas())
# Add tool rules to the agent_state specifically for this type of agent
agent_state.tool_rules = [
InitToolRule(tool_name="store_memories"),
ChildToolRule(tool_name="store_memories", children=["rethink_user_memory"]),
ContinueToolRule(tool_name="rethink_user_memory"),
TerminalToolRule(tool_name="finish_rethinking_memory"),
]
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
assistant_message = chat_completion.choices[0].message
# Process tool calls
tool_call = assistant_message.tool_calls[0]
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
if function_name == "store_memories":
print("Called store_memories")
print(function_args)
chunks = function_args.get("chunks", [])
results = [self.store_memory(agent_state=self.convo_agent_state, **chunk_args) for chunk_args in chunks]
aggregated_result = next((res for res, _ in results if res is not None), None)
aggregated_success = all(success for _, success in results)
else:
raise ValueError("Error: Unknown tool function '{function_name}'")
assistant_message = {
"role": "assistant",
"content": assistant_message.content,
"tool_calls": [
{
"id": tool_call.id,
"type": "function",
"function": {"name": function_name, "arguments": tool_call.function.arguments},
}
],
}
openai_messages.append(assistant_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=assistant_message,
model=agent_state.llm_config.model,
name=function_name,
)
# Summarize
current_in_context_messages, new_in_context_messages = await super()._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
tool_call_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"content": package_function_response(was_success=aggregated_success, response_string=str(aggregated_result)),
}
openai_messages.append(tool_call_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=tool_call_message,
model=agent_state.llm_config.model,
name=function_name,
tool_returns=[ToolReturn(status="success" if aggregated_success else "error")],
)
new_in_context_messages, updated = self.summarizer.summarize(
in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages
)
self.agent_manager.set_in_context_messages(
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
# 2. Execute rethink block memory loop
human_block_content = self.agent_manager.get_block_with_label(
agent_id=self.agent_id, block_label=self.target_block_label, actor=self.actor
)
rethink_command = f"""
Here is the current memory block created earlier:
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
### CURRENT MEMORY
{human_block_content}
### END CURRENT MEMORY
Please refine this block:
- Merge in any new facts and remove outdated or contradictory details.
- Organize related information together (e.g., preferences, background, ongoing goals).
- Add any light, supportable inferences that deepen understanding—but do not invent unsupported details.
Use `rethink_user_memory(new_memory)` as many times as you need to iteratively improve the text. When it's fully polished and complete, call `finish_rethinking_memory()`.
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
"""
rethink_command = UserMessage(content=rethink_command)
openai_messages.append(rethink_command.model_dump())
Executes a tool and returns (result, success_flag).
"""
# Special memory case
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
if not target_tool:
return f"Tool not found: {tool_name}", False
for _ in range(max_steps):
request = self._build_openai_request(openai_messages, agent_state, tools=self._build_sleeptime_tools())
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
assistant_message = chat_completion.choices[0].message
try:
if target_tool.name == "rethink_user_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
return self.rethink_user_memory(agent_state=agent_state, **tool_args)
elif target_tool.name == "finish_rethinking_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
return "", True
elif target_tool.name == "store_memories" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
chunks = tool_args.get("chunks", [])
results = [self.store_memory(agent_state=self.convo_agent_state, **chunk_args) for chunk_args in chunks]
# Process tool calls
tool_call = assistant_message.tool_calls[0]
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
aggregated_result = next((res for res, _ in results if res is not None), None)
aggregated_success = all(success for _, success in results)
if function_name == "rethink_user_memory":
print("Called rethink_user_memory")
print(function_args)
result, success = self.rethink_user_memory(agent_state=agent_state, **function_args)
elif function_name == "finish_rethinking_memory":
print("Called finish_rethinking_memory")
result, success = None, True
break
return aggregated_result, aggregated_success # Note that here we store to the convo agent's archival memory
else:
print(f"Error: Unknown tool function '{function_name}'")
raise ValueError(f"Error: Unknown tool function '{function_name}'", False)
assistant_message = {
"role": "assistant",
"content": assistant_message.content,
"tool_calls": [
{
"id": tool_call.id,
"type": "function",
"function": {"name": function_name, "arguments": tool_call.function.arguments},
}
],
}
openai_messages.append(assistant_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=assistant_message,
model=agent_state.llm_config.model,
name=function_name,
)
)
tool_call_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"content": package_function_response(was_success=success, response_string=str(result)),
}
openai_messages.append(tool_call_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=tool_call_message,
model=agent_state.llm_config.model,
name=function_name,
tool_returns=[ToolReturn(status="success" if success else "error")],
)
)
result = f"Voice sleeptime agent tried invoking invalid tool with type {target_tool.tool_type}: {target_tool}"
return result, False
except Exception as e:
return f"Failed to call tool. Error: {e}", False
# Actually save the memory:
target_block = agent_state.memory.get_block(self.target_block_label)
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=target_block.value), actor=self.actor)
self.message_manager.create_many_messages(pydantic_msgs=in_context_messages, actor=self.actor)
return LettaResponse(messages=[msg for m in in_context_messages for msg in m.to_letta_messages()], usage=LettaUsageStatistics())
def _format_messages_llm_friendly(self):
messages = self.message_manager.list_messages_for_agent(agent_id=self.agent_id, actor=self.actor)
llm_friendly_messages = [f"{m.role}: {m.content[0].text}" for m in messages if m.content and isinstance(m.content[0], TextContent)]
return "\n".join(llm_friendly_messages)
def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState, tools: List[Tool]) -> ChatCompletionRequest:
openai_request = ChatCompletionRequest(
model=agent_state.llm_config.model, # TODO: Separate config for summarizer?
messages=openai_messages,
tools=tools,
tool_choice="required",
user=self.actor.id,
max_completion_tokens=agent_state.llm_config.max_tokens,
temperature=agent_state.llm_config.temperature,
stream=False,
)
return openai_request
def _build_store_memory_tool_schemas(self) -> List[Tool]:
"""
Build the schemas for the three memory-related tools.
"""
tools = [
Tool(
type="function",
function={
"name": "store_memories",
"description": "Archive coherent chunks of dialogue that will be evicted, preserving raw lines and a brief contextual description.",
"parameters": {
"type": "object",
"properties": {
"chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"start_index": {"type": "integer", "description": "Index of first line in original history."},
"end_index": {"type": "integer", "description": "Index of last line in original history."},
"context": {
"type": "string",
"description": "A high-level description providing context for why this chunk matters.",
},
},
"required": ["start_index", "end_index", "context"],
},
}
},
"required": ["chunks"],
"additionalProperties": False,
},
},
),
]
return tools
def _build_sleeptime_tools(self) -> List[Tool]:
tools = [
Tool(
type="function",
function={
"name": "rethink_user_memory",
"description": (
"Rewrite memory block for the main agent, new_memory should contain all current "
"information from the block that is not outdated or inconsistent, integrating any "
"new information, resulting in a new memory block that is organized, readable, and "
"comprehensive."
),
"parameters": {
"type": "object",
"properties": {
"new_memory": {
"type": "string",
"description": (
"The new memory with information integrated from the memory block. "
"If there is no new information, then this should be the same as the "
"content in the source block."
),
},
},
"required": ["new_memory"],
"additionalProperties": False,
},
},
),
Tool(
type="function",
function={
"name": "finish_rethinking_memory",
"description": ("This function is called when the agent is done rethinking the memory."),
"parameters": {
"type": "object",
"properties": {},
"required": [],
"additionalProperties": False,
},
},
),
]
return tools
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> Tuple[Optional[str], bool]:
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> Tuple[str, bool]:
if agent_state.memory.get_block(self.target_block_label) is None:
agent_state.memory.create_block(label=self.target_block_label, value=new_memory)
agent_state.memory.update_block_value(label=self.target_block_label, value=new_memory)
return None, True
def store_memory(self, start_index: int, end_index: int, context: str, agent_state: AgentState) -> Tuple[Optional[str], bool]:
target_block = agent_state.memory.get_block(self.target_block_label)
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=target_block.value), actor=self.actor)
return "", True
def store_memory(self, start_index: int, end_index: int, context: str, agent_state: AgentState) -> Tuple[str, bool]:
"""
Store a memory.
"""
try:
messages = self.message_transcripts[start_index : end_index + 1]
memory = self.serialize(messages, context)
memory = serialize_message_history(messages, context)
self.agent_manager.passage_manager.insert_passage(
agent_state=agent_state,
agent_id=agent_state.id,
@@ -335,63 +141,12 @@ Use `rethink_user_memory(new_memory)` as many times as you need to iteratively i
)
self.agent_manager.rebuild_system_prompt(agent_id=agent_state.id, actor=self.actor, force=True)
return None, True
return "", True
except Exception as e:
return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False
def serialize(self, messages: List[str], context: str) -> str:
"""
Produce an XML document like:
<memory>
<messages>
<message>…</message>
<message>…</message>
</messages>
<context>…</context>
</memory>
"""
root = ET.Element("memory")
msgs_el = ET.SubElement(root, "messages")
for msg in messages:
m = ET.SubElement(msgs_el, "message")
m.text = msg
sum_el = ET.SubElement(root, "context")
sum_el.text = context
# ET.tostring will escape reserved chars for you
return ET.tostring(root, encoding="unicode")
def deserialize(self, xml_str: str) -> Tuple[List[str], str]:
"""
Parse the XML back into (messages, context). Raises ValueError if tags are missing.
"""
try:
root = ET.fromstring(xml_str)
except ET.ParseError as e:
raise ValueError(f"Invalid XML: {e}")
msgs_el = root.find("messages")
if msgs_el is None:
raise ValueError("Missing <messages> section")
messages = []
for m in msgs_el.findall("message"):
# .text may be None if empty, so coerce to empty string
messages.append(m.text or "")
sum_el = root.find("context")
if sum_el is None:
raise ValueError("Missing <context> section")
context = sum_el.text or ""
return messages, context
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""
This agent is synchronous-only. If called in an async context, raise an error.

View File

@@ -6,15 +6,10 @@ from pydantic import BaseModel, Field
def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None:
"""
Rewrite memory block for the main agent, new_memory should contain all current
information from the block that is not outdated or inconsistent, integrating any
new information, resulting in a new memory block that is organized, readable, and
comprehensive.
Rewrite memory block for the main agent, new_memory should contain all current information from the block that is not outdated or inconsistent, integrating any new information, resulting in a new memory block that is organized, readable, and comprehensive.
Args:
new_memory (str): The new memory with information integrated from the memory block.
If there is no new information, then this should be the same as
the content in the source block.
new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
Returns:
None: None is always returned as this function does not produce a response.
@@ -34,26 +29,27 @@ def finish_rethinking_memory(agent_state: "AgentState") -> None: # type: ignore
class MemoryChunk(BaseModel):
start_index: int = Field(..., description="Index of the first line in the original conversation history.")
end_index: int = Field(..., description="Index of the last line in the original conversation history.")
context: str = Field(..., description="A concise, high-level note explaining why this chunk matters.")
start_index: int = Field(
...,
description="Zero-based index of the first evicted line in this chunk.",
)
end_index: int = Field(
...,
description="Zero-based index of the last evicted line (inclusive).",
)
context: str = Field(
...,
description="1-3 sentence paraphrase capturing key facts/details, user preferences, or goals that this chunk reveals—written for future retrieval.",
)
def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None:
"""
Archive coherent chunks of dialogue that will be evicted, preserving raw lines
and a brief contextual description.
Persist dialogue that is about to fall out of the agent's context window.
Args:
agent_state (AgentState):
The agents current memory state, exposing both its in-session history
and the archival memory API.
chunks (List[MemoryChunk]):
A list of MemoryChunk models, each representing a segment to archive:
• start_index (int): Index of the first line in the original history.
• end_index (int): Index of the last line in the original history.
• context (str): A concise, high-level description of why this chunk
matters and what it contains.
Each chunk pinpoints a contiguous block of **evicted** lines and provides a short, forward-looking synopsis (`context`) that will be embedded for future semantic lookup.
Returns:
None
@@ -69,20 +65,12 @@ def search_memory(
end_minutes_ago: Optional[int],
) -> Optional[str]:
"""
Look in long-term or earlier-conversation memory only when the user asks about
something missing from the visible context. The users latest utterance is sent
automatically as the main query.
Look in long-term or earlier-conversation memory only when the user asks about something missing from the visible context. The user's latest utterance is sent automatically as the main query.
Args:
agent_state (AgentState): The current state of the agent, including its
memory stores and context.
convo_keyword_queries (Optional[List[str]]): Extra keywords or identifiers
(e.g., order ID, place name) to refine the search when the request is vague.
Set to None if the users utterance is already specific.
start_minutes_ago (Optional[int]): Newer bound of the time window for results,
specified in minutes ago. Set to None if no lower time bound is needed.
end_minutes_ago (Optional[int]): Older bound of the time window for results,
specified in minutes ago. Set to None if no upper time bound is needed.
convo_keyword_queries (Optional[List[str]]): Extra keywords (e.g., order ID, place name). Use *null* if not appropriate for the latest user message.
start_minutes_ago (Optional[int]): Newer bound of the time window for results, specified in minutes ago. Use *null* if no lower time bound is needed.
end_minutes_ago (Optional[int]): Older bound of the time window, in minutes ago. Use *null* if no upper bound is needed.
Returns:
Optional[str]: A formatted string of matching memory entries, or None if no

View File

@@ -53,7 +53,7 @@ Example output:
**Phase 2: Refine User Memory using `rethink_user_memory` and `finish_rethinking_memory`**
After the `store_memories` tool call is processed, you will be presented with the current content of the `human` memory block (the read-write block storing details about the user).
After the `store_memories` tool call is processed, consider the current content of the `human` memory block (the read-write block storing details about the user).
- Your goal is to refine this block by integrating information from the **ENTIRE** conversation transcript (both `Older` and `Newer` sections) with the existing memory content.
- Refinement Principles:
@@ -67,8 +67,7 @@ After the `store_memories` tool call is processed, you will be presented with th
- Tool Usage:
- Use the `rethink_user_memory(new_memory: string)` tool iteratively. Each call MUST submit the complete, rewritten version of the `human` memory block as you refine it.
- Continue calling `rethink_user_memory` until you are satisfied that the memory block is accurate, comprehensive, organized, and up-to-date according to the principles above.
- Once the `human` block is fully polished, call the `finish_rethinking_memory()` tool exactly once to signal completion.
- Once the `human` block is fully polished, call the `finish_rethinking_memory` tool exactly once to signal completion.
Output Requirements:
- You MUST ONLY output tool calls in the specified sequence: First `store_memories` (once), then one or more `rethink_user_memory` calls, and finally `finish_rethinking_memory` (once).
- Do not output any other text or explanations outside of the required JSON tool call format.

View File

@@ -1,9 +1,8 @@
import asyncio
import json
import traceback
from typing import List, Tuple
from typing import List, Optional, Tuple
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.enums import MessageRole
@@ -22,7 +21,11 @@ class Summarizer:
"""
def __init__(
self, mode: SummarizationMode, summarizer_agent: VoiceSleeptimeAgent, message_buffer_limit: int = 10, message_buffer_min: int = 3
self,
mode: SummarizationMode,
summarizer_agent: Optional["VoiceSleeptimeAgent"] = None,
message_buffer_limit: int = 10,
message_buffer_min: int = 3,
):
self.mode = mode
@@ -90,39 +93,42 @@ class Summarizer:
logger.info("Nothing to evict, returning in context messages as is.")
return all_in_context_messages, False
evicted_messages = all_in_context_messages[1:target_trim_index]
if self.summarizer_agent:
# Only invoke if summarizer agent is passed in
# Format
formatted_evicted_messages = format_transcript(evicted_messages)
formatted_in_context_messages = format_transcript(updated_in_context_messages)
evicted_messages = all_in_context_messages[1:target_trim_index]
# Update the message transcript of the memory agent
self.summarizer_agent.update_message_transcript(message_transcripts=formatted_evicted_messages + formatted_in_context_messages)
# Format
formatted_evicted_messages = format_transcript(evicted_messages)
formatted_in_context_messages = format_transcript(updated_in_context_messages)
# Add line numbers to the formatted messages
line_number = 0
for i in range(len(formatted_evicted_messages)):
formatted_evicted_messages[i] = f"{line_number}. " + formatted_evicted_messages[i]
line_number += 1
for i in range(len(formatted_in_context_messages)):
formatted_in_context_messages[i] = f"{line_number}. " + formatted_in_context_messages[i]
line_number += 1
# TODO: This is hyperspecific to voice, generalize!
# Update the message transcript of the memory agent
self.summarizer_agent.update_message_transcript(message_transcripts=formatted_evicted_messages + formatted_in_context_messages)
evicted_messages_str = "\n".join(formatted_evicted_messages)
in_context_messages_str = "\n".join(formatted_in_context_messages)
summary_request_text = f"""You're a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren't lost.
# Add line numbers to the formatted messages
line_number = 0
for i in range(len(formatted_evicted_messages)):
formatted_evicted_messages[i] = f"{line_number}. " + formatted_evicted_messages[i]
line_number += 1
for i in range(len(formatted_in_context_messages)):
formatted_in_context_messages[i] = f"{line_number}. " + formatted_in_context_messages[i]
line_number += 1
(Older) Evicted Messages:\n
{evicted_messages_str}\n
evicted_messages_str = "\n".join(formatted_evicted_messages)
in_context_messages_str = "\n".join(formatted_in_context_messages)
summary_request_text = f"""You're a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren't lost.
(Newer) In-Context Messages:\n
{in_context_messages_str}
"""
print(summary_request_text)
# Fire-and-forget the summarization task
self.fire_and_forget(
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
)
(Older) Evicted Messages:\n
{evicted_messages_str}\n
(Newer) In-Context Messages:\n
{in_context_messages_str}
"""
# Fire-and-forget the summarization task
self.fire_and_forget(
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
)
return [all_in_context_messages[0]] + updated_in_context_messages, True

View File

@@ -17,7 +17,7 @@ from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
@@ -29,6 +29,7 @@ from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager
@@ -336,19 +337,17 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
)
sleeptime_agent = agent_manager.create_agent(request, actor=actor)
async_client = AsyncOpenAI()
memory_agent = VoiceSleeptimeAgent(
agent_id=sleeptime_agent.id,
convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent
openai_client=async_client,
message_manager=MessageManager(),
agent_manager=agent_manager,
actor=actor,
block_manager=BlockManager(),
passage_manager=PassageManager(),
target_block_label="human",
message_transcripts=MESSAGE_TRANSCRIPTS,
)
memory_agent.update_message_transcript(MESSAGE_TRANSCRIPTS)
summarizer = Summarizer(
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
@@ -389,12 +388,15 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
@pytest.mark.asyncio
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
async def test_voice_sleeptime_agent(disable_e2b_api_key, client, voice_agent):
"""Tests chat completion streaming using the Async OpenAI client."""
agent_manager = AgentManager()
user_manager = UserManager()
actor = user_manager.get_default_user()
finish_rethinking_memory_tool = client.tools.list(name="finish_rethinking_memory")[0]
store_memories_tool = client.tools.list(name="store_memories")[0]
rethink_user_memory_tool = client.tools.list(name="rethink_user_memory")[0]
request = CreateAgent(
name=voice_agent.name + "-sleeptime",
agent_type=AgentType.voice_sleeptime_agent,
@@ -408,50 +410,46 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
project_id=voice_agent.project_id,
tool_ids=[finish_rethinking_memory_tool.id, store_memories_tool.id, rethink_user_memory_tool.id],
)
sleeptime_agent = agent_manager.create_agent(request, actor=actor)
async_client = AsyncOpenAI()
memory_agent = VoiceSleeptimeAgent(
agent_id=sleeptime_agent.id,
convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent
openai_client=async_client,
message_manager=MessageManager(),
agent_manager=agent_manager,
actor=actor,
block_manager=BlockManager(),
passage_manager=PassageManager(),
target_block_label="human",
message_transcripts=MESSAGE_TRANSCRIPTS,
)
memory_agent.update_message_transcript(MESSAGE_TRANSCRIPTS)
results = await memory_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=SUMMARY_REQ_TEXT)])])
messages = results.messages
# --- Basic structural check ---
assert isinstance(messages, list)
assert len(messages) >= 5, "Expected at least 5 messages in the sequence"
# collect the names of every tool call
seen_tool_calls = set()
# --- Message 0: initial UserMessage ---
assert isinstance(messages[0], UserMessage), "First message should be a UserMessage"
for idx, msg in enumerate(messages):
# 1) Print whatever “content” this message carries
if hasattr(msg, "content") and msg.content is not None:
print(f"Message {idx} content:\n{msg.content}\n")
# 2) If its a ToolCallMessage, also grab its name and print the raw args
elif isinstance(msg, ToolCallMessage):
name = msg.tool_call.name
args = msg.tool_call.arguments
seen_tool_calls.add(name)
print(f"Message {idx} TOOL CALL: {name}\nArguments:\n{args}\n")
# 3) Otherwise just dump the repr
else:
print(f"Message {idx} repr:\n{msg!r}\n")
# --- Message 1: store_memories ToolCall ---
assert isinstance(messages[1], ToolCallMessage), "Second message should be ToolCallMessage"
assert messages[1].name == "store_memories", "Expected store_memories tool call"
# --- Message 2: store_memories ToolReturn ---
assert isinstance(messages[2], ToolReturnMessage), "Third message should be ToolReturnMessage"
assert messages[2].name == "store_memories", "Expected store_memories tool return"
assert messages[2].status == "success", "store_memories tool return should be successful"
# --- Message 3: rethink_user_memory ToolCall ---
assert isinstance(messages[3], ToolCallMessage), "Fourth message should be ToolCallMessage"
assert messages[3].name == "rethink_user_memory", "Expected rethink_user_memory tool call"
# --- Message 4: rethink_user_memory ToolReturn ---
assert isinstance(messages[4], ToolReturnMessage), "Fifth message should be ToolReturnMessage"
assert messages[4].name == "rethink_user_memory", "Expected rethink_user_memory tool return"
assert messages[4].status == "success", "rethink_user_memory tool return should be successful"
# now verify we saw each of the three calls at least once
expected = {"store_memories", "rethink_user_memory", "finish_rethinking_memory"}
missing = expected - seen_tool_calls
assert not missing, f"Did not see calls to: {', '.join(missing)}"
@pytest.mark.asyncio

View File

@@ -1,9 +1,21 @@
import os
import threading
import pytest
from dotenv import load_dotenv
from letta_client import Letta
import letta.functions.function_sets.base as base_functions
from letta import LocalClient, create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import (
get_finish_rethinking_memory_schema,
get_rethink_user_memory_schema,
get_search_memory_schema,
get_store_memories_schema,
)
from tests.utils import wait_for_server
@pytest.fixture(scope="function")
@@ -15,6 +27,35 @@ def client():
yield client
def _run_server():
    """Starts the Letta server in a background thread."""
    # NOTE(review): load_dotenv() runs before the server import below,
    # presumably so the server module sees .env configuration at import
    # time — keep this ordering.
    load_dotenv()
    # Imported lazily so the (heavy) server app is only loaded when a local
    # server is actually needed.
    from letta.server.rest_api.app import start_server
    start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
    """Return the base URL of a running Letta server, starting one if needed.

    If LETTA_SERVER_URL is set in the environment, an external server is
    assumed to be running there. Otherwise an in-process server is launched
    on the default port in a daemon thread and polled until it responds.
    """
    env_url = os.getenv("LETTA_SERVER_URL")
    url = env_url if env_url is not None else "http://localhost:8283"
    if not env_url:
        background = threading.Thread(target=_run_server, daemon=True)
        background.start()
        wait_for_server(url)
    return url
@pytest.fixture(scope="session")
def letta_client(server_url):
    """Session-scoped REST client pointed at the test server.

    Upserts the base tool set so tests can look tools up by name.
    """
    rest_client = Letta(base_url=server_url)
    rest_client.tools.upsert_base_tools()
    return rest_client
@pytest.fixture(scope="function")
def agent_obj(client: LocalClient):
"""Create a test agent that we can call functions on"""
@@ -98,3 +139,57 @@ def test_recall(client, agent_obj):
# Conversation search
result = base_functions.conversation_search(agent_obj, "banana")
assert keyword in result
def test_get_rethink_user_memory_parsing(letta_client):
tool = letta_client.tools.list(name="rethink_user_memory")[0]
json_schema = tool.json_schema
# Remove `request_heartbeat` from properties
json_schema["parameters"]["properties"].pop("request_heartbeat", None)
# Remove it from the required list if present
required = json_schema["parameters"].get("required", [])
if "request_heartbeat" in required:
required.remove("request_heartbeat")
assert json_schema == get_rethink_user_memory_schema()
def test_get_finish_rethinking_memory_parsing(letta_client):
tool = letta_client.tools.list(name="finish_rethinking_memory")[0]
json_schema = tool.json_schema
# Remove `request_heartbeat` from properties
json_schema["parameters"]["properties"].pop("request_heartbeat", None)
# Remove it from the required list if present
required = json_schema["parameters"].get("required", [])
if "request_heartbeat" in required:
required.remove("request_heartbeat")
assert json_schema == get_finish_rethinking_memory_schema()
def test_store_memories_parsing(letta_client):
tool = letta_client.tools.list(name="store_memories")[0]
json_schema = tool.json_schema
# Remove `request_heartbeat` from properties
json_schema["parameters"]["properties"].pop("request_heartbeat", None)
# Remove it from the required list if present
required = json_schema["parameters"].get("required", [])
if "request_heartbeat" in required:
required.remove("request_heartbeat")
assert json_schema == get_store_memories_schema()
def test_search_memory_parsing(letta_client):
tool = letta_client.tools.list(name="search_memory")[0]
json_schema = tool.json_schema
# Remove `request_heartbeat` from properties
json_schema["parameters"]["properties"].pop("request_heartbeat", None)
# Remove it from the required list if present
required = json_schema["parameters"].get("required", [])
if "request_heartbeat" in required:
required.remove("request_heartbeat")
assert json_schema == get_search_memory_schema()

View File

@@ -0,0 +1,95 @@
def get_rethink_user_memory_schema():
    """Expected JSON schema for the `rethink_user_memory` base tool."""
    new_memory_property = {
        "type": "string",
        "description": (
            "The new memory with information integrated from the memory block. "
            "If there is no new information, then this should be the same as the "
            "content in the source block."
        ),
    }
    return {
        "name": "rethink_user_memory",
        "description": (
            "Rewrite memory block for the main agent, new_memory should contain all current "
            "information from the block that is not outdated or inconsistent, integrating any "
            "new information, resulting in a new memory block that is organized, readable, and "
            "comprehensive."
        ),
        "parameters": {
            "type": "object",
            "properties": {"new_memory": new_memory_property},
            "required": ["new_memory"],
        },
    }
def get_finish_rethinking_memory_schema():
    """Expected JSON schema for the `finish_rethinking_memory` base tool (takes no arguments)."""
    no_args = {"type": "object", "properties": {}, "required": []}
    return {
        "name": "finish_rethinking_memory",
        "description": "This function is called when the agent is done rethinking the memory.",
        "parameters": no_args,
    }
def get_store_memories_schema():
    """Expected JSON schema for the `store_memories` base tool."""
    # One entry per contiguous run of evicted transcript lines.
    chunk_item = {
        "type": "object",
        "properties": {
            "start_index": {
                "type": "integer",
                "description": "Zero-based index of the first evicted line in this chunk.",
            },
            "end_index": {
                "type": "integer",
                "description": "Zero-based index of the last evicted line (inclusive).",
            },
            "context": {
                "type": "string",
                "description": "1-3 sentence paraphrase capturing key facts/details, user preferences, or goals that this chunk reveals—written for future retrieval.",
            },
        },
        "required": ["start_index", "end_index", "context"],
    }
    return {
        "name": "store_memories",
        "description": "Persist dialogue that is about to fall out of the agents context window.",
        "parameters": {
            "type": "object",
            "properties": {
                "chunks": {
                    "type": "array",
                    "items": chunk_item,
                    "description": "Each chunk pinpoints a contiguous block of **evicted** lines and provides a short, forward-looking synopsis (`context`) that will be embedded for future semantic lookup.",
                }
            },
            "required": ["chunks"],
        },
    }
def get_search_memory_schema():
    """Expected JSON schema for the `search_memory` base tool."""
    properties = {
        "convo_keyword_queries": {
            "type": "array",
            "items": {"type": "string"},
            "description": "Extra keywords (e.g., order ID, place name). Use *null* if not appropriate for the latest user message.",
        },
        "start_minutes_ago": {
            "type": "integer",
            "description": "Newer bound of the time window for results, specified in minutes ago. Use *null* if no lower time bound is needed.",
        },
        "end_minutes_ago": {
            "type": "integer",
            "description": "Older bound of the time window, in minutes ago. Use *null* if no upper bound is needed.",
        },
    }
    return {
        "name": "search_memory",
        "description": (
            "Look in long-term or earlier-conversation memory only when the user asks about "
            "something missing from the visible context. The users latest utterance is sent "
            "automatically as the main query."
        ),
        # All three parameters are optional; the user's latest message is the implicit query.
        "parameters": {"type": "object", "properties": properties, "required": []},
    }