feat: Add Letta voice tools (#1922)

This commit is contained in:
Matthew Zhou
2025-04-29 10:46:29 -07:00
committed by GitHub
parent 97a0ccf682
commit 19bd790c58
14 changed files with 205 additions and 48 deletions

View File

@@ -74,8 +74,8 @@ class EphemeralMemoryAgent(BaseAgent):
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
if function_name == "store_memory":
print("Called store_memory")
if function_name == "store_memories":
print("Called store_memories")
print(function_args)
for chunk_args in function_args.get("chunks"):
self.store_memory(agent_state=agent_state, **chunk_args)
@@ -115,7 +115,7 @@ Please refine this block:
- Organize related information together (e.g., preferences, background, ongoing goals).
- Add any light, supportable inferences that deepen understanding—but do not invent unsupported details.
Use `rethink_memory(new_memory)` as many times as you need to iteratively improve the text. When its fully polished and complete, call `finish_rethinking_memory()`.
Use `rethink_user_memory(new_memory)` as many times as you need to iteratively improve the text. When it's fully polished and complete, call `finish_rethinking_memory()`.
"""
rethink_command = UserMessage(content=rethink_command)
openai_messages.append(rethink_command.model_dump())
@@ -132,10 +132,10 @@ Use `rethink_memory(new_memory)` as many times as you need to iteratively improv
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
if function_name == "rethink_memory":
print("Called rethink_memory")
if function_name == "rethink_user_memor":
print("Called rethink_user_memor")
print(function_args)
result = self.rethink_memory(agent_state=agent_state, **function_args)
result = self.rethink_user_memor(agent_state=agent_state, **function_args)
elif function_name == "finish_rethinking_memory":
print("Called finish_rethinking_memory")
break
@@ -192,7 +192,7 @@ Use `rethink_memory(new_memory)` as many times as you need to iteratively improv
Tool(
type="function",
function={
"name": "store_memory",
"name": "store_memories",
"description": "Archive coherent chunks of dialogue that will be evicted, preserving raw lines and a brief contextual description.",
"parameters": {
"type": "object",
@@ -227,7 +227,7 @@ Use `rethink_memory(new_memory)` as many times as you need to iteratively improv
Tool(
type="function",
function={
"name": "rethink_memory",
"name": "rethink_user_memory",
"description": (
"Rewrite memory block for the main agent, new_memory should contain all current "
"information from the block that is not outdated or inconsistent, integrating any "
@@ -268,7 +268,7 @@ Use `rethink_memory(new_memory)` as many times as you need to iteratively improv
return tools
def rethink_memory(self, new_memory: str, agent_state: AgentState) -> str:
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> str:
if agent_state.memory.get_block(self.target_block_label) is None:
agent_state.memory.create_block(label=self.target_block_label, value=new_memory)
@@ -365,7 +365,7 @@ When given a full transcript with lines marked (Older) or (Newer), you should:
- end_index: the last line's index
- context: a blurb explaining why this chunk matters
Return exactly one JSON tool call to `store_memory`, consider this miniature example:
Return exactly one JSON tool call to `store_memories`, consider this miniature example:
---
@@ -385,7 +385,7 @@ Example output:
```json
{
"name": "store_memory",
"name": "store_memories",
"arguments": {
"chunks": [
{
@@ -410,7 +410,7 @@ SYSTEM
You are a Memory-Updater agent. Your job is to iteratively refine the given memory block until it's concise, organized, and complete.
Instructions:
- Call `rethink_memory(new_memory: string)` as many times as you like. Each call should submit a fully revised version of the block so far.
- Call `rethink_user_memory(new_memory: string)` as many times as you like. Each call should submit a fully revised version of the block so far.
- When you're fully satisfied, call `finish_rethinking_memory()`.
- Don't output anything else—only the JSON for these tool calls.

View File

@@ -18,6 +18,8 @@ MCP_TOOL_TAG_NAME_PREFIX = "mcp" # full format, mcp:server_name
LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"
LETTA_MULTI_AGENT_TOOL_MODULE_NAME = "letta.functions.function_sets.multi_agent"
LETTA_VOICE_TOOL_MODULE_NAME = "letta.functions.function_sets.voice"
# String in the error message for when the context window is too large
# Example full message:
@@ -67,10 +69,20 @@ BASE_SLEEPTIME_TOOLS = [
# "archival_memory_search",
# "conversation_search",
]
# Base tools for the voice agent
BASE_VOICE_SLEEPTIME_CHAT_TOOLS = ["search_memory"]
# Base memory tools for sleeptime agent
BASE_VOICE_SLEEPTIME_TOOLS = [
"store_memories",
"rethink_user_memory",
"finish_rethinking_memory",
]
# Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS)
LETTA_TOOL_SET = set(
BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS
)
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)

View File

@@ -186,16 +186,6 @@ def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_labe
return None
def finish_rethinking_memory(agent_state: "AgentState") -> None: # type: ignore
"""
This function is called when the agent is done rethinking the memory.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
return None
## Attempted v2 of sleep-time function set, meant to work better across all types
SNIPPET_LINES: int = 4

View File

@@ -0,0 +1,92 @@
## Voice chat + sleeptime tools
from typing import List, Optional
from pydantic import BaseModel, Field
# Signature/docstring stub for the `rethink_user_memory` tool. The body is a
# deliberate no-op: per the inline comment below, the behavior lives in the
# agent loop rather than in this function.
def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None:
    """
    Rewrite memory block for the main agent, new_memory should contain all current
    information from the block that is not outdated or inconsistent, integrating any
    new information, resulting in a new memory block that is organized, readable, and
    comprehensive.

    Args:
        new_memory (str): The new memory with information integrated from the memory block.
            If there is no new information, then this should be the same as
            the content in the source block.

    Returns:
        None: None is always returned as this function does not produce a response.
    """
    # This is implemented directly in the agent loop
    return None
# Signature/docstring stub for the `finish_rethinking_memory` tool; it takes no
# tool arguments and does nothing itself.
# NOTE(review): the memory agent loop appears to treat a call to this tool as
# its signal to stop iterating — confirm against the loop implementation.
def finish_rethinking_memory(agent_state: "AgentState") -> None:  # type: ignore
    """
    This function is called when the agent is done rethinking the memory.

    Returns:
        None: None is always returned as this function does not produce a response.
    """
    return None
# One segment of conversation history to archive, identified by the indices of
# its first and last lines plus a short note on why it matters. Used as the
# element type of the `chunks` argument of `store_memories`.
# (Documented with `#` comments rather than a class docstring so the model's
# generated schema is left untouched.)
class MemoryChunk(BaseModel):
    # All three fields are required (`...` default).
    start_index: int = Field(..., description="Index of the first line in the original conversation history.")
    end_index: int = Field(..., description="Index of the last line in the original conversation history.")
    context: str = Field(..., description="A concise, high-level note explaining why this chunk matters.")
# Signature/docstring stub for the `store_memories` tool. The body is a
# deliberate no-op: per the inline comment below, the behavior lives in the
# agent loop rather than in this function.
def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None:
    """
    Archive coherent chunks of dialogue that will be evicted, preserving raw lines
    and a brief contextual description.

    Args:
        agent_state (AgentState):
            The agent's current memory state, exposing both its in-session history
            and the archival memory API.
        chunks (List[MemoryChunk]):
            A list of MemoryChunk models, each representing a segment to archive:
            • start_index (int): Index of the first line in the original history.
            • end_index (int): Index of the last line in the original history.
            • context (str): A concise, high-level description of why this chunk
              matters and what it contains.

    Returns:
        None
    """
    # This is implemented directly in the agent loop
    return None
# Signature/docstring stub for the `search_memory` tool. The body is a
# deliberate no-op: per the inline comment below, the behavior lives in the
# agent loop rather than in this function.
# NOTE(review): the docstring calls `start_minutes_ago` the *newer* bound and
# `end_minutes_ago` the *older* bound, while also describing them as the lower
# and upper time bounds — confirm this wording matches how the agent loop
# interprets the window.
def search_memory(
    agent_state: "AgentState",
    convo_keyword_queries: Optional[List[str]],
    start_minutes_ago: Optional[int],
    end_minutes_ago: Optional[int],
) -> Optional[str]:
    """
    Look in long-term or earlier-conversation memory only when the user asks about
    something missing from the visible context. The user's latest utterance is sent
    automatically as the main query.

    Args:
        agent_state (AgentState): The current state of the agent, including its
            memory stores and context.
        convo_keyword_queries (Optional[List[str]]): Extra keywords or identifiers
            (e.g., order ID, place name) to refine the search when the request is vague.
            Set to None if the user's utterance is already specific.
        start_minutes_ago (Optional[int]): Newer bound of the time window for results,
            specified in minutes ago. Set to None if no lower time bound is needed.
        end_minutes_ago (Optional[int]): Older bound of the time window for results,
            specified in minutes ago. Set to None if no upper time bound is needed.

    Returns:
        Optional[str]: A formatted string of matching memory entries, or None if no
            relevant memories are found.
    """
    # This is implemented directly in the agent loop
    return None

View File

@@ -78,9 +78,7 @@ def {func_name}(**kwargs):
return func_name, wrapper_function_str.strip()
def execute_composio_action(
action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None
) -> tuple[str, str]:
def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
import os
from composio.exceptions import (
@@ -110,10 +108,10 @@ def execute_composio_action(
except ComposioSDKError as e:
raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))
if response["error"]:
if "error" in response:
raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))
return response["data"]
return response.get("data")
def generate_langchain_tool_wrapper(

View File

@@ -7,6 +7,7 @@ class ToolType(str, Enum):
LETTA_MEMORY_CORE = "letta_memory_core"
LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
LETTA_SLEEPTIME_CORE = "letta_sleeptime_core"
LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core"
EXTERNAL_COMPOSIO = "external_composio"
EXTERNAL_LANGCHAIN = "external_langchain"
# TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote?

View File

@@ -29,6 +29,8 @@ class AgentType(str, Enum):
memgpt_agent = "memgpt_agent"
split_thread_agent = "split_thread_agent"
sleeptime_agent = "sleeptime_agent"
voice_convo_agent = "voice_convo_agent"
voice_sleeptime_agent = "voice_sleeptime_agent"
class AgentState(OrmMetadataBase, validate_assignment=True):

View File

@@ -83,6 +83,16 @@ class SleeptimeManagerUpdate(ManagerConfig):
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
# Manager configuration for a voice sleeptime group; `manager_agent_id` is required.
# NOTE(review): `manager_type` is pinned to `ManagerType.sleeptime` rather than a
# voice-specific value. If manager configs are distinguished by `manager_type`
# elsewhere, this one may be indistinguishable from the regular sleeptime
# manager — confirm this is intended.
class VoiceSleeptimeManager(ManagerConfig):
    manager_type: Literal[ManagerType.sleeptime] = Field(ManagerType.sleeptime, description="")
    manager_agent_id: str = Field(..., description="")
# Update payload for a voice sleeptime manager config: same `manager_type`
# literal (`ManagerType.sleeptime`), but `manager_agent_id` is optional and
# defaults to None.
class VoiceSleeptimeManagerUpdate(ManagerConfig):
    manager_type: Literal[ManagerType.sleeptime] = Field(ManagerType.sleeptime, description="")
    manager_agent_id: Optional[str] = Field(None, description="")
# class SwarmGroup(ManagerConfig):
# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="")

View File

@@ -7,6 +7,7 @@ from letta.constants import (
FUNCTION_RETURN_CHAR_LIMIT,
LETTA_CORE_TOOL_MODULE_NAME,
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
LETTA_VOICE_TOOL_MODULE_NAME,
MCP_TOOL_TAG_NAME_PREFIX,
)
from letta.functions.ast_parsers import get_function_name_and_description
@@ -98,15 +99,15 @@ class Tool(BaseTool):
except Exception as e:
error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}"
logger.error(error_msg)
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}:
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}:
# If it's letta core tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}:
# If it's letta multi-agent tool, we also generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name)
elif self.tool_type in {ToolType.LETTA_SLEEPTIME_CORE}:
# If it's letta sleeptime core tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
elif self.tool_type in {ToolType.LETTA_VOICE_SLEEPTIME_CORE}:
# If it's letta voice tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_VOICE_TOOL_MODULE_NAME, function_name=self.name)
# At this point, we need to validate that at least json_schema is populated
if not self.json_schema:

View File

@@ -44,7 +44,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.group import GroupCreate, SleeptimeManager
from letta.schemas.group import GroupCreate, SleeptimeManager, VoiceSleeptimeManager
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
from letta.schemas.letta_message_content import TextContent
@@ -769,7 +769,10 @@ class SyncServer(Server):
log_event(name="end create_agent db")
if request.enable_sleeptime:
main_agent = self.create_sleeptime_agent(main_agent=main_agent, actor=actor)
if request.agent_type == AgentType.voice_convo_agent:
main_agent = self.create_voice_sleeptime_agent(main_agent=main_agent, actor=actor)
else:
main_agent = self.create_sleeptime_agent(main_agent=main_agent, actor=actor)
return main_agent
@@ -788,7 +791,10 @@ class SyncServer(Server):
if request.enable_sleeptime:
agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent.multi_agent_group is None:
self.create_sleeptime_agent(main_agent=agent, actor=actor)
if agent.agent_type == AgentType.voice_convo_agent:
self.create_voice_sleeptime_agent(main_agent=agent, actor=actor)
else:
self.create_sleeptime_agent(main_agent=agent, actor=actor)
return self.agent_manager.update_agent(
agent_id=agent_id,
@@ -828,6 +834,38 @@ class SyncServer(Server):
)
return self.agent_manager.get_agent_by_id(agent_id=main_agent.id, actor=actor)
def create_voice_sleeptime_agent(self, main_agent: AgentState, actor: User) -> AgentState:
    """Create a voice sleeptime agent for ``main_agent`` and put it in a group.

    The new agent is named ``<main_agent.name>-sleeptime``, has agent type
    ``voice_sleeptime_agent``, is given the ids of ``main_agent``'s memory
    blocks as ``block_ids`` plus a new ``memory_persona`` block, and reuses the
    main agent's LLM config, embedding config and project id. A group is then
    created whose only listed agent is the new one and whose manager is
    ``main_agent``.

    Args:
        main_agent: The agent the voice sleeptime agent is created for.
        actor: The user passed through to the agent and group managers.

    Returns:
        ``main_agent`` re-fetched by id after the group has been created.
    """
    # TODO: Inject system
    sleeptime_name = main_agent.name + "-sleeptime"
    shared_block_ids = [blk.id for blk in main_agent.memory.blocks]
    persona_block = CreateBlock(
        label="memory_persona",
        value=get_persona_text("sleeptime_memory_persona"),
    )
    agent_spec = CreateAgent(
        name=sleeptime_name,
        agent_type=AgentType.voice_sleeptime_agent,
        block_ids=shared_block_ids,
        memory_blocks=[persona_block],
        llm_config=main_agent.llm_config,
        embedding_config=main_agent.embedding_config,
        project_id=main_agent.project_id,
    )
    companion = self.agent_manager.create_agent(agent_create=agent_spec, actor=actor)

    # The main agent acts as the group's manager; the new agent is its only listed member.
    member_ids = [companion.id]
    manager_config = VoiceSleeptimeManager(manager_agent_id=main_agent.id)
    group_spec = GroupCreate(description="", agent_ids=member_ids, manager_config=manager_config)
    self.group_manager.create_group(group=group_spec, actor=actor)

    return self.agent_manager.get_agent_by_id(agent_id=main_agent.id, actor=actor)
# convert name->id
# TODO: These can be moved to agent_manager

View File

@@ -7,6 +7,8 @@ from letta.constants import (
BASE_MEMORY_TOOLS,
BASE_SLEEPTIME_TOOLS,
BASE_TOOLS,
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
BASE_VOICE_SLEEPTIME_TOOLS,
LETTA_TOOL_SET,
MCP_TOOL_TAG_NAME_PREFIX,
MULTI_AGENT_TOOLS,
@@ -190,7 +192,7 @@ class ToolManager:
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py and multi_agent.py"""
functions_to_schema = {}
module_names = ["base", "multi_agent"]
module_names = ["base", "multi_agent", "voice"]
for module_name in module_names:
full_module_name = f"letta.functions.function_sets.{module_name}"
@@ -223,9 +225,12 @@ class ToolManager:
elif name in BASE_SLEEPTIME_TOOLS:
tool_type = ToolType.LETTA_SLEEPTIME_CORE
tags = [tool_type.value]
elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE
tags = [tool_type.value]
else:
raise ValueError(
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS}"
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}"
)
# create to tool

View File

@@ -1,6 +1,5 @@
import os
import threading
import time
import uuid
import pytest
@@ -10,6 +9,7 @@ from letta_client.core.api_error import ApiError
from sqlalchemy import delete
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
from tests.utils import wait_for_server
# Constants
SERVER_PORT = 8283
@@ -41,7 +41,7 @@ def client(request):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
wait_for_server(server_url)
print("Running client tests with server:", server_url)
# create the Letta client

View File

@@ -11,7 +11,7 @@ from sqlalchemy import delete
from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.constants import BASE_MEMORY_TOOLS, BASE_SLEEPTIME_TOOLS, BASE_TOOLS, DEFAULT_PRESET, MULTI_AGENT_TOOLS
from letta.constants import DEFAULT_PRESET
from letta.helpers.datetime_helpers import get_utc_time
from letta.orm import FileMetadata, Source
from letta.schemas.agent import AgentState
@@ -344,13 +344,6 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
assert all(visited_ids.values())
def test_list_tools(client: Union[LocalClient, RESTClient]):
tools = client.upsert_base_tools()
tool_names = [t.name for t in tools]
expected = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS)
assert sorted(tool_names) == sorted(expected)
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
# clear sources
for source in client.list_sources():

View File

@@ -20,6 +20,8 @@ from letta.constants import (
BASE_MEMORY_TOOLS,
BASE_SLEEPTIME_TOOLS,
BASE_TOOLS,
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
BASE_VOICE_SLEEPTIME_TOOLS,
LETTA_TOOL_EXECUTION_DIR,
MCP_TOOL_TAG_NAME_PREFIX,
MULTI_AGENT_TOOLS,
@@ -2294,7 +2296,16 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
def test_upsert_base_tools(server: SyncServer, default_user):
tools = server.tool_manager.upsert_base_tools(actor=default_user)
expected_tool_names = sorted(set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS))
expected_tool_names = sorted(
set(
BASE_TOOLS
+ BASE_MEMORY_TOOLS
+ MULTI_AGENT_TOOLS
+ BASE_SLEEPTIME_TOOLS
+ BASE_VOICE_SLEEPTIME_TOOLS
+ BASE_VOICE_SLEEPTIME_CHAT_TOOLS
)
)
assert sorted([t.name for t in tools]) == expected_tool_names
# Call it again to make sure it doesn't create duplicates
@@ -2311,6 +2322,10 @@ def test_upsert_base_tools(server: SyncServer, default_user):
assert t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE
elif t.name in BASE_SLEEPTIME_TOOLS:
assert t.tool_type == ToolType.LETTA_SLEEPTIME_CORE
elif t.name in BASE_VOICE_SLEEPTIME_TOOLS:
assert t.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE
elif t.name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
assert t.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE
else:
pytest.fail(f"The tool name is unrecognized as a base tool: {t.name}")
assert t.source_code is None