fix: Fix voice agent and tests (#2305)

This commit is contained in:
Matthew Zhou
2025-05-21 11:05:54 -07:00
committed by GitHub
parent c634b3f6e3
commit 3f6a710894
4 changed files with 98 additions and 75 deletions

View File

@@ -299,7 +299,7 @@ class VoiceAgent(BaseAgent):
num_messages: int | None = None,
num_archival_memories: int | None = None,
) -> List[Message]:
return super()._rebuild_memory_async(
return await super()._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)

View File

@@ -1,10 +1,10 @@
import os
import threading
import subprocess
import sys
from unittest.mock import MagicMock
import pytest
from dotenv import load_dotenv
from letta_client import AsyncLetta
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
@@ -35,7 +35,7 @@ from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import get_persona_text
from tests.utils import wait_for_server
from tests.utils import create_tool_from_func, wait_for_server
MESSAGE_TRANSCRIPTS = [
"user: Hey, Ive been thinking about planning a road trip up the California coast next month.",
@@ -92,17 +92,6 @@ Youre a memory-recall helper for an AI that can only keep the last 4 messages
# --- Server Management --- #
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
return server
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
@@ -111,31 +100,66 @@ def _run_server():
start_server(debug=True)
@pytest.fixture(scope="module")
def server_url():
    """
    Yield the base URL of a Letta HTTP server for the tests in this module.

    If ``LETTA_SERVER_URL`` is set to a non-empty value, that URL is yielded
    unchanged and no server is started. Otherwise the server is launched in a
    separate process using the 'uvicorn' CLI, so its event loop and DB pool
    stay completely isolated from pytest-asyncio, and the process is shut down
    again when the fixture is torn down.
    """
    override_url = os.getenv("LETTA_SERVER_URL")
    if override_url:
        # The user pointed us at an existing server: nothing to start or stop.
        yield override_url
        return

    url = "http://127.0.0.1:8283"

    # Build the command to launch uvicorn on the FastAPI app.
    # If you need TLS or reload settings from start_server(), you can add
    # "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well.
    cmd = [
        sys.executable,
        "-m",
        "uvicorn",
        "letta.server.rest_api.app:app",
        "--host",
        "127.0.0.1",
        "--port",
        "8283",
    ]

    server_proc = subprocess.Popen(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    try:
        # Wait until the HTTP port is accepting connections.
        wait_for_server(url)
        yield url
    finally:
        # Teardown runs even when startup fails, so the subprocess we spawned
        # is never left behind.
        server_proc.terminate()
        try:
            server_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            # The server ignored the terminate request; force it down.
            server_proc.kill()
            server_proc.wait()
@pytest.fixture(scope="module")
def server():
    """Create a module-scoped ``SyncServer`` with the base tools upserted for the default user."""
    letta_config = LettaConfig.load()
    print("CONFIG PATH", letta_config.config_path)
    letta_config.save()

    sync_server = SyncServer()
    # Upsert the base tools for the user returned by get_user_or_default().
    default_actor = sync_server.user_manager.get_user_or_default()
    sync_server.tool_manager.upsert_base_tools(actor=default_actor)
    return sync_server
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = AsyncLetta(base_url=server_url)
yield client
@pytest.fixture(scope="function")
async def roll_dice_tool(client):
@pytest.fixture
async def roll_dice_tool(server):
def roll_dice():
"""
Rolls a 6 sided die.
@@ -145,13 +169,13 @@ async def roll_dice_tool(client):
"""
return "Rolled a 10!"
tool = await client.tools.upsert_from_function(func=roll_dice)
# Yield the created tool
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=roll_dice), actor=actor)
yield tool
@pytest.fixture(scope="function")
async def weather_tool(client):
@pytest.fixture
async def weather_tool(server):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
@@ -176,22 +200,20 @@ async def weather_tool(client):
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
tool = await client.tools.upsert_from_function(func=get_weather)
# Yield the created tool
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
yield tool
@pytest.fixture
def composio_gmail_get_profile_tool(default_user):
    """Yield the Composio ``GMAIL_GET_PROFILE`` tool, created or updated on behalf of ``default_user``."""
    gmail_profile_request = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
    manager = ToolManager()
    yield manager.create_or_update_composio_tool(tool_create=gmail_profile_request, actor=default_user)
@pytest.fixture(scope="function")
@pytest.fixture
def voice_agent(server, actor):
server.tool_manager.upsert_base_tools(actor=actor)
main_agent = server.create_agent(
request=CreateAgent(
agent_type=AgentType.voice_convo_agent,
@@ -268,9 +290,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
# --- Tests --- #
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"])
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor):
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, server_url, group_id, actor):
request = _get_chat_request("How are you?")
server.tool_manager.upsert_base_tools(actor=actor)
@@ -303,10 +325,10 @@ async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, serv
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint):
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint, server_url):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(message)
@@ -318,9 +340,9 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor):
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url):
server.group_manager.modify_group(
group_id=group_id,
group_update=GroupUpdate(
@@ -350,8 +372,8 @@ async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, g
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio(loop_scope="session")
async def test_summarization(disable_e2b_api_key, voice_agent):
@pytest.mark.asyncio(loop_scope="module")
async def test_summarization(disable_e2b_api_key, voice_agent, server_url):
agent_manager = AgentManager()
user_manager = UserManager()
actor = user_manager.get_default_user()
@@ -422,8 +444,8 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
summarizer.fire_and_forget.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
@pytest.mark.asyncio(loop_scope="module")
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent, server_url):
"""Tests chat completion streaming using the Async OpenAI client."""
agent_manager = AgentManager()
tool_manager = ToolManager()
@@ -488,8 +510,8 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
assert not missing, f"Did not see calls to: {', '.join(missing)}"
@pytest.mark.asyncio(loop_scope="session")
async def test_init_voice_convo_agent(voice_agent, server, actor):
@pytest.mark.asyncio(loop_scope="module")
async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
assert voice_agent.enable_sleeptime == True
main_agent_tools = [tool.name for tool in voice_agent.tools]

View File

@@ -1,5 +1,5 @@
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from typing import Tuple
from unittest.mock import AsyncMock, patch
import pytest
@@ -15,7 +15,6 @@ from anthropic.types.beta.messages import (
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.config import LettaConfig
from letta.functions.functions import parse_source_code
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
@@ -25,10 +24,10 @@ from letta.schemas.job import BatchJob
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from tests.utils import create_tool_from_func
# --------------------------------------------------------------------------- #
# Test Constants / Helpers
@@ -45,24 +44,6 @@ MODELS = {
EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
def create_tool_from_func(
func,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
):
source_code = parse_source_code(func)
source_type = "python"
if not tags:
tags = []
return Tool(
source_type=source_type,
source_code=source_code,
tags=tags,
description=description,
)
# --------------------------------------------------------------------------- #
# Test Fixtures
# --------------------------------------------------------------------------- #

View File

@@ -4,15 +4,17 @@ import string
import time
from datetime import datetime, timezone
from importlib import util
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Optional, Tuple
import requests
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.schemas.enums import MessageRole
from letta.schemas.file import FileMetadata
from letta.schemas.message import Message
from letta.schemas.tool import Tool
from letta.settings import TestSettings
from .constants import TIMEOUT
@@ -199,3 +201,21 @@ def wait_for_server(url, timeout=30, interval=0.5):
def random_string(length: int) -> str:
    """Return a string of ``length`` characters drawn at random from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
def create_tool_from_func(
    func,
    tags: Optional[List[str]] = None,
    description: Optional[str] = None,
):
    """
    Build a ``Tool`` from a Python function.

    The source code is obtained with ``parse_source_code(func)`` and the
    source type is always "python". A missing or empty ``tags`` value is
    replaced by an empty list; ``description`` is passed through unchanged.
    """
    return Tool(
        source_type="python",
        source_code=parse_source_code(func),
        tags=tags or [],
        description=description,
    )