From b9b109f586c6bb39ab2bf62204e5d2839aee1be4 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 23 Jul 2025 18:15:45 -0700 Subject: [PATCH] fix: Fix letta agent batch tests (#3524) --- tests/test_letta_agent_batch.py | 76 +++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 40d68170..6507a4ee 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -18,7 +18,6 @@ from letta.agents.letta_agent_batch import LettaAgentBatch from letta.config import LettaConfig from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches -from letta.orm import Base from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType from letta.schemas.job import BatchJob @@ -27,7 +26,6 @@ from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.llm_batch_job import AgentStepState from letta.schemas.message import MessageCreate from letta.schemas.tool_rule import InitToolRule -from letta.server.db import db_context from letta.server.server import SyncServer from tests.utils import create_tool_from_func @@ -51,17 +49,25 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"] # --------------------------------------------------------------------------- # -@pytest.fixture(scope="function") +@pytest.fixture(scope="module") +def event_loop(): + """Use a single asyncio loop for the entire test session.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture def weather_tool(server): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. - Parameters: - location (str): The location to get the weather for. + Args: + location: The location to get the weather for. Returns: - str: A formatted string describing the weather in the given location. + A formatted string describing the weather in the given location. Raises: RuntimeError: If the request to fetch weather data fails. @@ -83,7 +89,7 @@ def weather_tool(server): yield tool -@pytest.fixture(scope="function") +@pytest.fixture def rethink_tool(server): def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore """ @@ -107,7 +113,7 @@ def rethink_tool(server): @pytest.fixture -def agents(server, weather_tool): +async def agents(server, weather_tool): """ Create three test agents with different models. @@ -116,8 +122,8 @@ def agents(server, weather_tool): """ actor = server.user_manager.get_user_or_default() - def create_agent(suffix, model_name): - return server.create_agent( + async def create_agent(suffix, model_name): + return await server.create_agent_async( CreateAgent( name=f"test_agent_{suffix}", include_base_tools=True, @@ -130,9 +136,9 @@ def agents(server, weather_tool): ) return ( - create_agent("sonnet", MODELS["sonnet"]), - create_agent("haiku", MODELS["haiku"]), - create_agent("opus", MODELS["opus"]), + await create_agent("sonnet", MODELS["sonnet"]), + await create_agent("haiku", MODELS["haiku"]), + await create_agent("opus", MODELS["opus"]), ) @@ -283,18 +289,18 @@ def dummy_batch_response(): # Server and Database Management # --------------------------------------------------------------------------- # - -@pytest.fixture(autouse=True) -def clear_batch_tables(): - """Clear batch-related tables before each test.""" - with db_context() as session: - for table in reversed(Base.metadata.sorted_tables): - if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}: - session.execute(table.delete()) # Truncate table - session.commit() +# +# @pytest.fixture(autouse=True) +# def clear_batch_tables(): +# """Clear batch-related tables before each test.""" +# with db_context() as session: +# for table in reversed(Base.metadata.sorted_tables): +# if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}: +# session.execute(table.delete()) # Truncate table +# session.commit() -@pytest.fixture(scope="module") +@pytest.fixture def server(): """ Creates a SyncServer instance for testing. @@ -309,6 +315,20 @@ def server(): yield server +@pytest.fixture +def default_organization(server): + """Fixture to create and return the default organization.""" + org = server.organization_manager.create_default_organization() + yield org + + +@pytest.fixture +def default_user(server, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_default_user(org_id=default_organization.id) + yield user + + @pytest.fixture def batch_job(default_user, server): job = BatchJob( @@ -343,7 +363,7 @@ class MockAsyncIterable: # --------------------------------------------------------------------------- # -@pytest.mark.asyncio(loop_scope="module") +@pytest.mark.asyncio async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool): target_block_label = "human" new_memory = "banana" @@ -429,7 +449,7 @@ async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, defa assert block.value == new_memory -@pytest.mark.asyncio(loop_scope="module") +@pytest.mark.asyncio async def test_partial_error_from_anthropic_batch( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -596,7 +616,7 @@ async def test_partial_error_from_anthropic_batch( assert agent_messages[0].role == MessageRole.user, "Expected initial user message" -@pytest.mark.asyncio(loop_scope="module") +@pytest.mark.asyncio async def test_resume_step_some_stop( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -764,7 +784,7 @@ def _assert_descending_order(messages): return True -@pytest.mark.asyncio(loop_scope="module") +@pytest.mark.asyncio async def test_resume_step_after_request_all_continue( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -907,7 +927,7 @@ async def test_resume_step_after_request_all_continue( assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" -@pytest.mark.asyncio(loop_scope="module") +@pytest.mark.asyncio async def test_step_until_request_prepares_and_submits_batch_correctly( disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job ):