fix: Fix letta agent batch tests (#3524)
This commit is contained in:
@@ -18,7 +18,6 @@ from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.config import LettaConfig
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
|
||||
from letta.schemas.job import BatchJob
|
||||
@@ -27,7 +26,6 @@ from letta.schemas.letta_request import LettaBatchRequest
|
||||
from letta.schemas.llm_batch_job import AgentStepState
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
from tests.utils import create_tool_from_func
|
||||
|
||||
@@ -51,17 +49,25 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""Use a single asyncio loop for the entire test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool(server):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
|
||||
Parameters:
|
||||
location (str): The location to get the weather for.
|
||||
Args:
|
||||
location: The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
str: A formatted string describing the weather in the given location.
|
||||
A formatted string describing the weather in the given location.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the request to fetch weather data fails.
|
||||
@@ -83,7 +89,7 @@ def weather_tool(server):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture
|
||||
def rethink_tool(server):
|
||||
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
|
||||
"""
|
||||
@@ -107,7 +113,7 @@ def rethink_tool(server):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agents(server, weather_tool):
|
||||
async def agents(server, weather_tool):
|
||||
"""
|
||||
Create three test agents with different models.
|
||||
|
||||
@@ -116,8 +122,8 @@ def agents(server, weather_tool):
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
|
||||
def create_agent(suffix, model_name):
|
||||
return server.create_agent(
|
||||
async def create_agent(suffix, model_name):
|
||||
return await server.create_agent_async(
|
||||
CreateAgent(
|
||||
name=f"test_agent_{suffix}",
|
||||
include_base_tools=True,
|
||||
@@ -130,9 +136,9 @@ def agents(server, weather_tool):
|
||||
)
|
||||
|
||||
return (
|
||||
create_agent("sonnet", MODELS["sonnet"]),
|
||||
create_agent("haiku", MODELS["haiku"]),
|
||||
create_agent("opus", MODELS["opus"]),
|
||||
await create_agent("sonnet", MODELS["sonnet"]),
|
||||
await create_agent("haiku", MODELS["haiku"]),
|
||||
await create_agent("opus", MODELS["opus"]),
|
||||
)
|
||||
|
||||
|
||||
@@ -283,18 +289,18 @@ def dummy_batch_response():
|
||||
# Server and Database Management
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_batch_tables():
|
||||
"""Clear batch-related tables before each test."""
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
#
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def clear_batch_tables():
|
||||
# """Clear batch-related tables before each test."""
|
||||
# with db_context() as session:
|
||||
# for table in reversed(Base.metadata.sorted_tables):
|
||||
# if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
|
||||
# session.execute(table.delete()) # Truncate table
|
||||
# session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture
|
||||
def server():
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
@@ -309,6 +315,20 @@ def server():
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization(server):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_job(default_user, server):
|
||||
job = BatchJob(
|
||||
@@ -343,7 +363,7 @@ class MockAsyncIterable:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.asyncio
|
||||
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
|
||||
target_block_label = "human"
|
||||
new_memory = "banana"
|
||||
@@ -429,7 +449,7 @@ async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, defa
|
||||
assert block.value == new_memory
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_error_from_anthropic_batch(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
@@ -596,7 +616,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_some_stop(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
@@ -764,7 +784,7 @@ def _assert_descending_order(messages):
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_after_request_all_continue(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
@@ -907,7 +927,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user