feat: Move sleeptime voice agent to new agent loop (#1979)

This commit is contained in:
Matthew Zhou
2025-05-01 20:48:33 -07:00
committed by GitHub
parent 9722596a74
commit daa30d6662
10 changed files with 419 additions and 420 deletions

View File

@@ -1,332 +1,138 @@
import json
import xml.etree.ElementTree as ET
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import AsyncGenerator, List, Tuple, Union
import openai
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, serialize_message_history
from letta.agents.letta_agent import LettaAgent
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.block import BlockUpdate
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, UserMessage
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from letta.schemas.user import User
from letta.server.rest_api.utils import convert_in_context_letta_messages_to_openai, create_input_messages
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.system import package_function_response
from letta.services.passage_manager import PassageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.tracing import trace_method
# TODO: Move this to the new Letta Agent loop
class VoiceSleeptimeAgent(BaseAgent):
class VoiceSleeptimeAgent(LettaAgent):
"""
A stateless agent that helps with offline memory computations.
A special variant of the LettaAgent that helps with offline memory computations specifically for voice.
"""
def __init__(
self,
agent_id: str,
convo_agent_state: AgentState,
openai_client: openai.AsyncClient,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
target_block_label: str,
message_transcripts: List[str],
actor: User,
):
super().__init__(
agent_id=agent_id,
openai_client=openai_client,
message_manager=message_manager,
agent_manager=agent_manager,
block_manager=block_manager,
passage_manager=passage_manager,
actor=actor,
)
self.convo_agent_state = convo_agent_state
self.block_manager = block_manager
self.target_block_label = target_block_label
self.message_transcripts = message_transcripts
self.message_transcripts = []
self.summarizer = Summarizer(
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
summarizer_agent=None,
message_buffer_limit=20,
message_buffer_min=10,
)
def update_message_transcript(self, message_transcripts: List[str]):
self.message_transcripts = message_transcripts
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = 20) -> LettaResponse:
"""
Process the user's input message, allowing the model to call memory-related tools
until it decides to stop and provide a final response.
"""
agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
in_context_messages = create_input_messages(input_messages=input_messages, agent_id=self.agent_id, actor=self.actor)
openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
# 1. Store memories
request = self._build_openai_request(openai_messages, agent_state, tools=self._build_store_memory_tool_schemas())
# Add tool rules to the agent_state specifically for this type of agent
agent_state.tool_rules = [
InitToolRule(tool_name="store_memories"),
ChildToolRule(tool_name="store_memories", children=["rethink_user_memory"]),
ContinueToolRule(tool_name="rethink_user_memory"),
TerminalToolRule(tool_name="finish_rethinking_memory"),
]
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
assistant_message = chat_completion.choices[0].message
# Process tool calls
tool_call = assistant_message.tool_calls[0]
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
if function_name == "store_memories":
print("Called store_memories")
print(function_args)
chunks = function_args.get("chunks", [])
results = [self.store_memory(agent_state=self.convo_agent_state, **chunk_args) for chunk_args in chunks]
aggregated_result = next((res for res, _ in results if res is not None), None)
aggregated_success = all(success for _, success in results)
else:
raise ValueError("Error: Unknown tool function '{function_name}'")
assistant_message = {
"role": "assistant",
"content": assistant_message.content,
"tool_calls": [
{
"id": tool_call.id,
"type": "function",
"function": {"name": function_name, "arguments": tool_call.function.arguments},
}
],
}
openai_messages.append(assistant_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=assistant_message,
model=agent_state.llm_config.model,
name=function_name,
)
# Summarize
current_in_context_messages, new_in_context_messages = await super()._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
)
tool_call_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"content": package_function_response(was_success=aggregated_success, response_string=str(aggregated_result)),
}
openai_messages.append(tool_call_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=tool_call_message,
model=agent_state.llm_config.model,
name=function_name,
tool_returns=[ToolReturn(status="success" if aggregated_success else "error")],
)
new_in_context_messages, updated = self.summarizer.summarize(
in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages
)
self.agent_manager.set_in_context_messages(
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
# 2. Execute rethink block memory loop
human_block_content = self.agent_manager.get_block_with_label(
agent_id=self.agent_id, block_label=self.target_block_label, actor=self.actor
)
rethink_command = f"""
Here is the current memory block created earlier:
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
### CURRENT MEMORY
{human_block_content}
### END CURRENT MEMORY
Please refine this block:
- Merge in any new facts and remove outdated or contradictory details.
- Organize related information together (e.g., preferences, background, ongoing goals).
- Add any light, supportable inferences that deepen understanding—but do not invent unsupported details.
Use `rethink_user_memory(new_memory)` as many times as you need to iteratively improve the text. When its fully polished and complete, call `finish_rethinking_memory()`.
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
"""
rethink_command = UserMessage(content=rethink_command)
openai_messages.append(rethink_command.model_dump())
Executes a tool and returns (result, success_flag).
"""
# Special memory case
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
if not target_tool:
return f"Tool not found: {tool_name}", False
for _ in range(max_steps):
request = self._build_openai_request(openai_messages, agent_state, tools=self._build_sleeptime_tools())
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
assistant_message = chat_completion.choices[0].message
try:
if target_tool.name == "rethink_user_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
return self.rethink_user_memory(agent_state=agent_state, **tool_args)
elif target_tool.name == "finish_rethinking_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
return "", True
elif target_tool.name == "store_memories" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
chunks = tool_args.get("chunks", [])
results = [self.store_memory(agent_state=self.convo_agent_state, **chunk_args) for chunk_args in chunks]
# Process tool calls
tool_call = assistant_message.tool_calls[0]
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
aggregated_result = next((res for res, _ in results if res is not None), None)
aggregated_success = all(success for _, success in results)
if function_name == "rethink_user_memory":
print("Called rethink_user_memory")
print(function_args)
result, success = self.rethink_user_memory(agent_state=agent_state, **function_args)
elif function_name == "finish_rethinking_memory":
print("Called finish_rethinking_memory")
result, success = None, True
break
return aggregated_result, aggregated_success # Note that here we store to the convo agent's archival memory
else:
print(f"Error: Unknown tool function '{function_name}'")
raise ValueError(f"Error: Unknown tool function '{function_name}'", False)
assistant_message = {
"role": "assistant",
"content": assistant_message.content,
"tool_calls": [
{
"id": tool_call.id,
"type": "function",
"function": {"name": function_name, "arguments": tool_call.function.arguments},
}
],
}
openai_messages.append(assistant_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=assistant_message,
model=agent_state.llm_config.model,
name=function_name,
)
)
tool_call_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"content": package_function_response(was_success=success, response_string=str(result)),
}
openai_messages.append(tool_call_message)
in_context_messages.append(
Message.dict_to_message(
agent_id=self.agent_id,
openai_message_dict=tool_call_message,
model=agent_state.llm_config.model,
name=function_name,
tool_returns=[ToolReturn(status="success" if success else "error")],
)
)
result = f"Voice sleeptime agent tried invoking invalid tool with type {target_tool.tool_type}: {target_tool}"
return result, False
except Exception as e:
return f"Failed to call tool. Error: {e}", False
# Actually save the memory:
target_block = agent_state.memory.get_block(self.target_block_label)
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=target_block.value), actor=self.actor)
self.message_manager.create_many_messages(pydantic_msgs=in_context_messages, actor=self.actor)
return LettaResponse(messages=[msg for m in in_context_messages for msg in m.to_letta_messages()], usage=LettaUsageStatistics())
def _format_messages_llm_friendly(self):
messages = self.message_manager.list_messages_for_agent(agent_id=self.agent_id, actor=self.actor)
llm_friendly_messages = [f"{m.role}: {m.content[0].text}" for m in messages if m.content and isinstance(m.content[0], TextContent)]
return "\n".join(llm_friendly_messages)
def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState, tools: List[Tool]) -> ChatCompletionRequest:
openai_request = ChatCompletionRequest(
model=agent_state.llm_config.model, # TODO: Separate config for summarizer?
messages=openai_messages,
tools=tools,
tool_choice="required",
user=self.actor.id,
max_completion_tokens=agent_state.llm_config.max_tokens,
temperature=agent_state.llm_config.temperature,
stream=False,
)
return openai_request
def _build_store_memory_tool_schemas(self) -> List[Tool]:
"""
Build the schemas for the three memory-related tools.
"""
tools = [
Tool(
type="function",
function={
"name": "store_memories",
"description": "Archive coherent chunks of dialogue that will be evicted, preserving raw lines and a brief contextual description.",
"parameters": {
"type": "object",
"properties": {
"chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"start_index": {"type": "integer", "description": "Index of first line in original history."},
"end_index": {"type": "integer", "description": "Index of last line in original history."},
"context": {
"type": "string",
"description": "A high-level description providing context for why this chunk matters.",
},
},
"required": ["start_index", "end_index", "context"],
},
}
},
"required": ["chunks"],
"additionalProperties": False,
},
},
),
]
return tools
def _build_sleeptime_tools(self) -> List[Tool]:
tools = [
Tool(
type="function",
function={
"name": "rethink_user_memory",
"description": (
"Rewrite memory block for the main agent, new_memory should contain all current "
"information from the block that is not outdated or inconsistent, integrating any "
"new information, resulting in a new memory block that is organized, readable, and "
"comprehensive."
),
"parameters": {
"type": "object",
"properties": {
"new_memory": {
"type": "string",
"description": (
"The new memory with information integrated from the memory block. "
"If there is no new information, then this should be the same as the "
"content in the source block."
),
},
},
"required": ["new_memory"],
"additionalProperties": False,
},
},
),
Tool(
type="function",
function={
"name": "finish_rethinking_memory",
"description": ("This function is called when the agent is done rethinking the memory."),
"parameters": {
"type": "object",
"properties": {},
"required": [],
"additionalProperties": False,
},
},
),
]
return tools
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> Tuple[Optional[str], bool]:
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> Tuple[str, bool]:
if agent_state.memory.get_block(self.target_block_label) is None:
agent_state.memory.create_block(label=self.target_block_label, value=new_memory)
agent_state.memory.update_block_value(label=self.target_block_label, value=new_memory)
return None, True
def store_memory(self, start_index: int, end_index: int, context: str, agent_state: AgentState) -> Tuple[Optional[str], bool]:
target_block = agent_state.memory.get_block(self.target_block_label)
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=target_block.value), actor=self.actor)
return "", True
def store_memory(self, start_index: int, end_index: int, context: str, agent_state: AgentState) -> Tuple[str, bool]:
"""
Store a memory.
"""
try:
messages = self.message_transcripts[start_index : end_index + 1]
memory = self.serialize(messages, context)
memory = serialize_message_history(messages, context)
self.agent_manager.passage_manager.insert_passage(
agent_state=agent_state,
agent_id=agent_state.id,
@@ -335,63 +141,12 @@ Use `rethink_user_memory(new_memory)` as many times as you need to iteratively i
)
self.agent_manager.rebuild_system_prompt(agent_id=agent_state.id, actor=self.actor, force=True)
return None, True
return "", True
except Exception as e:
return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False
def serialize(self, messages: List[str], context: str) -> str:
"""
Produce an XML document like:
<memory>
<messages>
<message>…</message>
<message>…</message>
</messages>
<context>…</context>
</memory>
"""
root = ET.Element("memory")
msgs_el = ET.SubElement(root, "messages")
for msg in messages:
m = ET.SubElement(msgs_el, "message")
m.text = msg
sum_el = ET.SubElement(root, "context")
sum_el.text = context
# ET.tostring will escape reserved chars for you
return ET.tostring(root, encoding="unicode")
def deserialize(self, xml_str: str) -> Tuple[List[str], str]:
"""
Parse the XML back into (messages, context). Raises ValueError if tags are missing.
"""
try:
root = ET.fromstring(xml_str)
except ET.ParseError as e:
raise ValueError(f"Invalid XML: {e}")
msgs_el = root.find("messages")
if msgs_el is None:
raise ValueError("Missing <messages> section")
messages = []
for m in msgs_el.findall("message"):
# .text may be None if empty, so coerce to empty string
messages.append(m.text or "")
sum_el = root.find("context")
if sum_el is None:
raise ValueError("Missing <context> section")
context = sum_el.text or ""
return messages, context
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""
This agent is synchronous-only. If called in an async context, raise an error.