fix: Fix voice agent and tests (#2305)
This commit is contained in:
@@ -299,7 +299,7 @@ class VoiceAgent(BaseAgent):
|
||||
num_messages: int | None = None,
|
||||
num_archival_memories: int | None = None,
|
||||
) -> List[Message]:
|
||||
return super()._rebuild_memory_async(
|
||||
return await super()._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import threading
|
||||
import subprocess
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
@@ -35,7 +35,7 @@ from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import get_persona_text
|
||||
from tests.utils import wait_for_server
|
||||
from tests.utils import create_tool_from_func, wait_for_server
|
||||
|
||||
MESSAGE_TRANSCRIPTS = [
|
||||
"user: Hey, I’ve been thinking about planning a road trip up the California coast next month.",
|
||||
@@ -92,17 +92,6 @@ You’re a memory-recall helper for an AI that can only keep the last 4 messages
|
||||
# --- Server Management --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
|
||||
|
||||
def _run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
@@ -111,31 +100,66 @@ def _run_server():
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url():
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
"""
|
||||
Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI,
|
||||
so its event loop and DB pool stay completely isolated from pytest-asyncio.
|
||||
"""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283")
|
||||
|
||||
# Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(url) # Allow server startup time
|
||||
# Build the command to launch uvicorn on your FastAPI app
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"letta.server.rest_api.app:app",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"8283",
|
||||
]
|
||||
# If you need TLS or reload settings from start_server(), you can add
|
||||
# "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well.
|
||||
|
||||
return url
|
||||
server_proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
# wait until the HTTP port is accepting connections
|
||||
wait_for_server(url)
|
||||
|
||||
yield url
|
||||
|
||||
# Teardown: kill the subprocess if we started it
|
||||
server_proc.terminate()
|
||||
server_proc.wait(timeout=10)
|
||||
else:
|
||||
yield url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
return server
|
||||
|
||||
|
||||
# --- Client Setup --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
client = AsyncLetta(base_url=server_url)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def roll_dice_tool(client):
|
||||
@pytest.fixture
|
||||
async def roll_dice_tool(server):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
@@ -145,13 +169,13 @@ async def roll_dice_tool(client):
|
||||
"""
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = await client.tools.upsert_from_function(func=roll_dice)
|
||||
# Yield the created tool
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=roll_dice), actor=actor)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def weather_tool(client):
|
||||
@pytest.fixture
|
||||
async def weather_tool(server):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
@@ -176,22 +200,20 @@ async def weather_tool(client):
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = await client.tools.upsert_from_function(func=get_weather)
|
||||
# Yield the created tool
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture
|
||||
def composio_gmail_get_profile_tool(default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture
|
||||
def voice_agent(server, actor):
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
main_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent_type=AgentType.voice_convo_agent,
|
||||
@@ -268,9 +290,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
# --- Tests --- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"])
|
||||
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor):
|
||||
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, server_url, group_id, actor):
|
||||
request = _get_chat_request("How are you?")
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
@@ -303,10 +325,10 @@ async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, serv
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint):
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint, server_url):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
@@ -318,9 +340,9 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor):
|
||||
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url):
|
||||
server.group_manager.modify_group(
|
||||
group_id=group_id,
|
||||
group_update=GroupUpdate(
|
||||
@@ -350,8 +372,8 @@ async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, g
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_summarization(disable_e2b_api_key, voice_agent):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_summarization(disable_e2b_api_key, voice_agent, server_url):
|
||||
agent_manager = AgentManager()
|
||||
user_manager = UserManager()
|
||||
actor = user_manager.get_default_user()
|
||||
@@ -422,8 +444,8 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
|
||||
summarizer.fire_and_forget.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent, server_url):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
agent_manager = AgentManager()
|
||||
tool_manager = ToolManager()
|
||||
@@ -488,8 +510,8 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
|
||||
assert not missing, f"Did not see calls to: {', '.join(missing)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_init_voice_convo_agent(voice_agent, server, actor):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
|
||||
|
||||
assert voice_agent.enable_sleeptime == True
|
||||
main_agent_tools = [tool.name for tool in voice_agent.tools]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Tuple
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -15,7 +15,6 @@ from anthropic.types.beta.messages import (
|
||||
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.config import LettaConfig
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
@@ -25,10 +24,10 @@ from letta.schemas.job import BatchJob
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import LettaBatchRequest
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
from tests.utils import create_tool_from_func
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Test Constants / Helpers
|
||||
@@ -45,24 +44,6 @@ MODELS = {
|
||||
EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
|
||||
|
||||
|
||||
def create_tool_from_func(
|
||||
func,
|
||||
tags: Optional[List[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
):
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
if not tags:
|
||||
tags = []
|
||||
|
||||
return Tool(
|
||||
source_type=source_type,
|
||||
source_code=source_code,
|
||||
tags=tags,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Test Fixtures
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
@@ -4,15 +4,17 @@ import string
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from importlib import util
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.settings import TestSettings
|
||||
|
||||
from .constants import TIMEOUT
|
||||
@@ -199,3 +201,21 @@ def wait_for_server(url, timeout=30, interval=0.5):
|
||||
|
||||
def random_string(length: int) -> str:
|
||||
return "".join(random.choices(string.ascii_letters + string.digits, k=length))
|
||||
|
||||
|
||||
def create_tool_from_func(
|
||||
func,
|
||||
tags: Optional[List[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
):
|
||||
source_code = parse_source_code(func)
|
||||
source_type = "python"
|
||||
if not tags:
|
||||
tags = []
|
||||
|
||||
return Tool(
|
||||
source_type=source_type,
|
||||
source_code=source_code,
|
||||
tags=tags,
|
||||
description=description,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user