diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 7f124038..5451dc6c 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -299,7 +299,7 @@ class VoiceAgent(BaseAgent): num_messages: int | None = None, num_archival_memories: int | None = None, ) -> List[Message]: - return super()._rebuild_memory_async( + return await super()._rebuild_memory_async( in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories ) diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 835a09e6..b3fc86dc 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -1,10 +1,10 @@ import os -import threading +import subprocess +import sys from unittest.mock import MagicMock import pytest from dotenv import load_dotenv -from letta_client import AsyncLetta from openai import AsyncOpenAI from openai.types.chat import ChatCompletionChunk @@ -35,7 +35,7 @@ from letta.services.summarizer.summarizer import Summarizer from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import get_persona_text -from tests.utils import wait_for_server +from tests.utils import create_tool_from_func, wait_for_server MESSAGE_TRANSCRIPTS = [ "user: Hey, I’ve been thinking about planning a road trip up the California coast next month.", @@ -92,17 +92,6 @@ You’re a memory-recall helper for an AI that can only keep the last 4 messages # --- Server Management --- # -@pytest.fixture(scope="module") -def server(): - config = LettaConfig.load() - print("CONFIG PATH", config.config_path) - - config.save() - - server = SyncServer() - return server - - def _run_server(): """Starts the Letta server in a background thread.""" load_dotenv() @@ -111,31 +100,66 @@ def _run_server(): start_server(debug=True) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server_url(): - """Ensures a server is running and returns its base URL.""" - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + """ + Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI, + so its event loop and DB pool stay completely isolated from pytest-asyncio. + """ + url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283") + # Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - wait_for_server(url) # Allow server startup time + # Build the command to launch uvicorn on your FastAPI app + cmd = [ + sys.executable, + "-m", + "uvicorn", + "letta.server.rest_api.app:app", + "--host", + "127.0.0.1", + "--port", + "8283", + ] + # If you need TLS or reload settings from start_server(), you can add + # "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well. - return url + server_proc = subprocess.Popen( + cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + # wait until the HTTP port is accepting connections + wait_for_server(url) + + yield url + + # Teardown: kill the subprocess if we started it + server_proc.terminate() + server_proc.wait(timeout=10) + else: + yield url + + +@pytest.fixture(scope="module") +def server(): + config = LettaConfig.load() + print("CONFIG PATH", config.config_path) + + config.save() + + server = SyncServer() + actor = server.user_manager.get_user_or_default() + server.tool_manager.upsert_base_tools(actor=actor) + return server # --- Client Setup --- # -@pytest.fixture(scope="session") -def client(server_url): - """Creates a REST client for testing.""" - client = AsyncLetta(base_url=server_url) - yield client - - -@pytest.fixture(scope="function") -async def roll_dice_tool(client): +@pytest.fixture +async def roll_dice_tool(server): def roll_dice(): """ Rolls a 6 sided die. @@ -145,13 +169,13 @@ async def roll_dice_tool(client): """ return "Rolled a 10!" - tool = await client.tools.upsert_from_function(func=roll_dice) - # Yield the created tool + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=roll_dice), actor=actor) yield tool -@pytest.fixture(scope="function") -async def weather_tool(client): +@pytest.fixture +async def weather_tool(server): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. @@ -176,22 +200,20 @@ async def weather_tool(client): else: raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - tool = await client.tools.upsert_from_function(func=get_weather) - # Yield the created tool + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor) yield tool -@pytest.fixture(scope="function") +@pytest.fixture def composio_gmail_get_profile_tool(default_user): tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user) yield tool -@pytest.fixture(scope="function") +@pytest.fixture def voice_agent(server, actor): - server.tool_manager.upsert_base_tools(actor=actor) - main_agent = server.create_agent( request=CreateAgent( agent_type=AgentType.voice_convo_agent, @@ -268,9 +290,9 @@ def _assert_valid_chunk(chunk, idx, chunks): # --- Tests --- # -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"]) -async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor): +async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, server_url, group_id, actor): request = _get_chat_request("How are you?") server.tool_manager.upsert_base_tools(actor=actor) @@ -303,10 +325,10 @@ async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, serv print(chunk.choices[0].delta.content) -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."]) @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) -async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint): +async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint, server_url): """Tests chat completion streaming using the Async OpenAI client.""" request = _get_chat_request(message) @@ -318,9 +340,9 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en print(chunk.choices[0].delta.content) -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) -async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor): +async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url): server.group_manager.modify_group( group_id=group_id, group_update=GroupUpdate( @@ -350,8 +372,8 @@ async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, g print(chunk.choices[0].delta.content) -@pytest.mark.asyncio(loop_scope="session") -async def test_summarization(disable_e2b_api_key, voice_agent): +@pytest.mark.asyncio(loop_scope="module") +async def test_summarization(disable_e2b_api_key, voice_agent, server_url): agent_manager = AgentManager() user_manager = UserManager() actor = user_manager.get_default_user() @@ -422,8 +444,8 @@ async def test_summarization(disable_e2b_api_key, voice_agent): summarizer.fire_and_forget.assert_called_once() -@pytest.mark.asyncio(loop_scope="session") -async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): +@pytest.mark.asyncio(loop_scope="module") +async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent, server_url): """Tests chat completion streaming using the Async OpenAI client.""" agent_manager = AgentManager() tool_manager = ToolManager() @@ -488,8 +510,8 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): assert not missing, f"Did not see calls to: {', '.join(missing)}" -@pytest.mark.asyncio(loop_scope="session") -async def test_init_voice_convo_agent(voice_agent, server, actor): +@pytest.mark.asyncio(loop_scope="module") +async def test_init_voice_convo_agent(voice_agent, server, actor, server_url): assert voice_agent.enable_sleeptime == True main_agent_tools = [tool.name for tool in voice_agent.tools] diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 70005133..da2a6666 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import List, Optional, Tuple +from typing import Tuple from unittest.mock import AsyncMock, patch import pytest @@ -15,7 +15,6 @@ from anthropic.types.beta.messages import ( from letta.agents.letta_agent_batch import LettaAgentBatch from letta.config import LettaConfig -from letta.functions.functions import parse_source_code from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base @@ -25,10 +24,10 @@ from letta.schemas.job import BatchJob from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.message import MessageCreate -from letta.schemas.tool import Tool from letta.schemas.tool_rule import InitToolRule from letta.server.db import db_context from letta.server.server import SyncServer +from tests.utils import create_tool_from_func # --------------------------------------------------------------------------- # # Test Constants / Helpers @@ -45,24 +44,6 @@ MODELS = { EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"] -def create_tool_from_func( - func, - tags: Optional[List[str]] = None, - description: Optional[str] = None, -): - source_code = parse_source_code(func) - source_type = "python" - if not tags: - tags = [] - - return Tool( - source_type=source_type, - source_code=source_code, - tags=tags, - description=description, - ) - - # --------------------------------------------------------------------------- # # Test Fixtures # --------------------------------------------------------------------------- # diff --git a/tests/utils.py b/tests/utils.py index 65c3ee2f..37b8ed6a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,15 +4,17 @@ import string import time from datetime import datetime, timezone from importlib import util -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Optional, Tuple import requests from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector +from letta.functions.functions import parse_source_code from letta.schemas.enums import MessageRole from letta.schemas.file import FileMetadata from letta.schemas.message import Message +from letta.schemas.tool import Tool from letta.settings import TestSettings from .constants import TIMEOUT @@ -199,3 +201,21 @@ def wait_for_server(url, timeout=30, interval=0.5): def random_string(length: int) -> str: return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + +def create_tool_from_func( + func, + tags: Optional[List[str]] = None, + description: Optional[str] = None, +): + source_code = parse_source_code(func) + source_type = "python" + if not tags: + tags = [] + + return Tool( + source_type=source_type, + source_code=source_code, + tags=tags, + description=description, + )