fix: Fix test letta agent batch (#2295)

This commit is contained in:
Matthew Zhou
2025-05-21 07:25:49 -07:00
committed by GitHub
parent 095a14cd1d
commit b7f2fb256a
2 changed files with 166 additions and 115 deletions

View File

@@ -1,7 +1,5 @@
import os
import threading
from datetime import datetime, timezone
from typing import Tuple
from typing import List, Optional, Tuple
from unittest.mock import AsyncMock, patch
import pytest
@@ -14,27 +12,26 @@ from anthropic.types.beta.messages import (
BetaMessageBatchRequestCounts,
BetaMessageBatchSucceededResult,
)
from dotenv import load_dotenv
from letta_client import Letta
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.config import LettaConfig
from letta.functions.functions import parse_source_code
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.agent import AgentState, AgentStepState, CreateAgent
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from tests.utils import wait_for_server
# --------------------------------------------------------------------------- #
# Test Constants
# Test Constants / Helpers
# --------------------------------------------------------------------------- #
# Model identifiers used in tests
@@ -48,13 +45,31 @@ MODELS = {
EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
def create_tool_from_func(
func,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
):
source_code = parse_source_code(func)
source_type = "python"
if not tags:
tags = []
return Tool(
source_type=source_type,
source_code=source_code,
tags=tags,
description=description,
)
# --------------------------------------------------------------------------- #
# Test Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture(scope="function")
def weather_tool(client):
def weather_tool(server):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
@@ -79,13 +94,14 @@ def weather_tool(client):
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
tool = client.tools.upsert_from_function(func=get_weather)
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def rethink_tool(client):
def rethink_tool(server):
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
"""
Re-evaluate the memory in block_name, integrating new and updated facts.
@@ -101,28 +117,33 @@ def rethink_tool(client):
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
return None
tool = client.tools.upsert_from_function(func=rethink_memory)
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor)
# Yield the created tool
yield tool
@pytest.fixture
def agents(client, weather_tool):
def agents(server, weather_tool):
"""
Create three test agents with different models.
Returns:
Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models
"""
actor = server.user_manager.get_user_or_default()
def create_agent(suffix, model_name):
return client.agents.create(
name=f"test_agent_{suffix}",
include_base_tools=True,
model=model_name,
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[weather_tool.id],
return server.create_agent(
CreateAgent(
name=f"test_agent_{suffix}",
include_base_tools=True,
model=model_name,
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[weather_tool.id],
),
actor=actor,
)
return (
@@ -290,32 +311,6 @@ def clear_batch_tables():
session.commit()
def run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""
Ensures a server is running and returns its base URL.
Uses environment variable if available, otherwise starts a server
in a background thread.
"""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
wait_for_server(url)
return url
@pytest.fixture(scope="module")
def server():
"""
@@ -324,14 +319,11 @@ def server():
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
return SyncServer()
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client connected to the test server."""
return Letta(base_url=server_url)
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture
@@ -368,23 +360,27 @@ class MockAsyncIterable:
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio(loop_scope="session")
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
@pytest.mark.asyncio(loop_scope="module")
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
target_block_label = "human"
new_memory = "banana"
agent = client.agents.create(
name=f"test_agent_rethink",
include_base_tools=True,
model=MODELS["sonnet"],
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[rethink_tool.id],
memory_blocks=[
{
"label": target_block_label,
"value": "Name: Matt",
},
],
actor = server.user_manager.get_user_or_default()
agent = await server.create_agent_async(
request=CreateAgent(
name=f"test_agent_rethink",
include_base_tools=True,
model=MODELS["sonnet"],
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[rethink_tool.id],
memory_blocks=[
{
"label": target_block_label,
"value": "Name: Matt",
},
],
),
actor=actor,
)
agents = [agent]
batch_requests = [
@@ -444,13 +440,13 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
await poll_running_llm_batches(server)
# Check that the tool has been executed correctly
agent = client.agents.retrieve(agent_id=agent.id)
agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor)
for block in agent.memory.blocks:
if block.label == target_block_label:
assert block.value == new_memory
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_partial_error_from_anthropic_batch(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@@ -610,7 +606,7 @@ async def test_partial_error_from_anthropic_batch(
assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_resume_step_some_stop(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@@ -773,7 +769,7 @@ def _assert_descending_order(messages):
return True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_resume_step_after_request_all_continue(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@@ -911,7 +907,7 @@ async def test_resume_step_after_request_all_continue(
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_step_until_request_prepares_and_submits_batch_correctly(
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
):