fix: Fix letta agent batch tests (#3524)

This commit is contained in:
Matthew Zhou
2025-07-23 18:15:45 -07:00
committed by GitHub
parent 440217062b
commit b9b109f586

View File

@@ -18,7 +18,6 @@ from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.config import LettaConfig
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
from letta.schemas.job import BatchJob
@@ -27,7 +26,6 @@ from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.llm_batch_job import AgentStepState
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from tests.utils import create_tool_from_func
@@ -51,17 +49,25 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
# --------------------------------------------------------------------------- #
@pytest.fixture(scope="function")
@pytest.fixture(scope="module")
def event_loop():
"""Use a single asyncio loop for the entire test session."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def weather_tool(server):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
Parameters:
location (str): The location to get the weather for.
Args:
location: The location to get the weather for.
Returns:
str: A formatted string describing the weather in the given location.
A formatted string describing the weather in the given location.
Raises:
RuntimeError: If the request to fetch weather data fails.
@@ -83,7 +89,7 @@ def weather_tool(server):
yield tool
@pytest.fixture(scope="function")
@pytest.fixture
def rethink_tool(server):
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
"""
@@ -107,7 +113,7 @@ def rethink_tool(server):
@pytest.fixture
def agents(server, weather_tool):
async def agents(server, weather_tool):
"""
Create three test agents with different models.
@@ -116,8 +122,8 @@ def agents(server, weather_tool):
"""
actor = server.user_manager.get_user_or_default()
def create_agent(suffix, model_name):
return server.create_agent(
async def create_agent(suffix, model_name):
return await server.create_agent_async(
CreateAgent(
name=f"test_agent_{suffix}",
include_base_tools=True,
@@ -130,9 +136,9 @@ def agents(server, weather_tool):
)
return (
create_agent("sonnet", MODELS["sonnet"]),
create_agent("haiku", MODELS["haiku"]),
create_agent("opus", MODELS["opus"]),
await create_agent("sonnet", MODELS["sonnet"]),
await create_agent("haiku", MODELS["haiku"]),
await create_agent("opus", MODELS["opus"]),
)
@@ -283,18 +289,18 @@ def dummy_batch_response():
# Server and Database Management
# --------------------------------------------------------------------------- #
@pytest.fixture(autouse=True)
def clear_batch_tables():
"""Clear batch-related tables before each test."""
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables):
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
session.execute(table.delete()) # Truncate table
session.commit()
#
# @pytest.fixture(autouse=True)
# def clear_batch_tables():
# """Clear batch-related tables before each test."""
# with db_context() as session:
# for table in reversed(Base.metadata.sorted_tables):
# if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
# session.execute(table.delete()) # Truncate table
# session.commit()
@pytest.fixture(scope="module")
@pytest.fixture
def server():
"""
Creates a SyncServer instance for testing.
@@ -309,6 +315,20 @@ def server():
yield server
@pytest.fixture
def default_organization(server):
"""Fixture to create and return the default organization."""
org = server.organization_manager.create_default_organization()
yield org
@pytest.fixture
def default_user(server, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_default_user(org_id=default_organization.id)
yield user
@pytest.fixture
def batch_job(default_user, server):
job = BatchJob(
@@ -343,7 +363,7 @@ class MockAsyncIterable:
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.asyncio
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
target_block_label = "human"
new_memory = "banana"
@@ -429,7 +449,7 @@ async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, defa
assert block.value == new_memory
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.asyncio
async def test_partial_error_from_anthropic_batch(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@@ -596,7 +616,7 @@ async def test_partial_error_from_anthropic_batch(
assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.asyncio
async def test_resume_step_some_stop(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@@ -764,7 +784,7 @@ def _assert_descending_order(messages):
return True
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.asyncio
async def test_resume_step_after_request_all_continue(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@@ -907,7 +927,7 @@ async def test_resume_step_after_request_all_continue(
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.asyncio
async def test_step_until_request_prepares_and_submits_batch_correctly(
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
):