fix: Fix test letta agent batch (#2295)
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -14,27 +12,26 @@ from anthropic.types.beta.messages import (
|
||||
BetaMessageBatchRequestCounts,
|
||||
BetaMessageBatchSucceededResult,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.config import LettaConfig
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentState, AgentStepState
|
||||
from letta.schemas.agent import AgentState, AgentStepState, CreateAgent
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
|
||||
from letta.schemas.job import BatchJob
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_request import LettaBatchRequest
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Test Constants
|
||||
# Test Constants / Helpers
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
# Model identifiers used in tests
|
||||
@@ -48,13 +45,31 @@ MODELS = {
|
||||
EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
|
||||
|
||||
|
||||
def create_tool_from_func(
    func,
    tags: Optional[List[str]] = None,
    description: Optional[str] = None,
):
    """Wrap a Python callable in a ``Tool`` schema object.

    Args:
        func: The callable whose source is extracted with ``parse_source_code``.
        tags: Optional tags for the tool; any falsy value (``None`` or an
            empty list) is replaced by a fresh empty list.
        description: Optional description passed through unchanged.

    Returns:
        A ``Tool`` with ``source_type`` set to ``"python"`` and ``source_code``
        set to the extracted source of ``func``.
    """
    return Tool(
        source_type="python",
        source_code=parse_source_code(func),
        tags=tags or [],
        description=description,
    )
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Test Fixtures
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def weather_tool(client):
|
||||
def weather_tool(server):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
@@ -79,13 +94,14 @@ def weather_tool(client):
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=get_weather)
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def rethink_tool(client):
|
||||
def rethink_tool(server):
|
||||
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
|
||||
"""
|
||||
Re-evaluate the memory in block_name, integrating new and updated facts.
|
||||
@@ -101,28 +117,33 @@ def rethink_tool(client):
|
||||
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
|
||||
return None
|
||||
|
||||
tool = client.tools.upsert_from_function(func=rethink_memory)
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agents(client, weather_tool):
|
||||
def agents(server, weather_tool):
|
||||
"""
|
||||
Create three test agents with different models.
|
||||
|
||||
Returns:
|
||||
Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
|
||||
def create_agent(suffix, model_name):
|
||||
return client.agents.create(
|
||||
name=f"test_agent_{suffix}",
|
||||
include_base_tools=True,
|
||||
model=model_name,
|
||||
tags=["test_agents"],
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[weather_tool.id],
|
||||
return server.create_agent(
|
||||
CreateAgent(
|
||||
name=f"test_agent_{suffix}",
|
||||
include_base_tools=True,
|
||||
model=model_name,
|
||||
tags=["test_agents"],
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[weather_tool.id],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -290,32 +311,6 @@ def clear_batch_tables():
|
||||
session.commit()
|
||||
|
||||
|
||||
def run_server():
    """Load environment settings, then launch the Letta REST server in debug mode.

    The call is made in whichever thread invokes this function; callers that
    want a background server pass this function as a thread target.
    """
    # Load the environment first so it is in place before the server module
    # is imported (presumably from a .env file -- confirm against dotenv setup).
    load_dotenv()

    # Deferred import: the REST app module is only pulled in when a server
    # is actually being started.
    from letta.server.rest_api.app import start_server as launch_rest_server

    launch_rest_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def server_url():
    """
    Provide the base URL of a Letta server for the test session.

    When ``LETTA_SERVER_URL`` is set to a non-empty value, that URL is returned
    as-is. Otherwise ``run_server`` is started in a daemon thread and the
    fixture waits for the server via ``wait_for_server`` before returning.
    """
    configured = os.getenv("LETTA_SERVER_URL")
    # An unset variable falls back to the local default; a variable that is
    # set but empty is kept verbatim and still triggers the local startup.
    url = "http://localhost:8283" if configured is None else configured

    if configured:
        return url

    background = threading.Thread(target=run_server, daemon=True)
    background.start()
    wait_for_server(url)
    return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""
|
||||
@@ -324,14 +319,11 @@ def server():
|
||||
Loads and saves config to ensure proper initialization.
|
||||
"""
|
||||
config = LettaConfig.load()
|
||||
|
||||
config.save()
|
||||
return SyncServer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def client(server_url):
    """Build a ``Letta`` client whose ``base_url`` is the ``server_url`` fixture value."""
    rest_client = Letta(base_url=server_url)
    return rest_client
|
||||
server = SyncServer(init_with_default_org_and_user=True)
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -368,23 +360,27 @@ class MockAsyncIterable:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
|
||||
target_block_label = "human"
|
||||
new_memory = "banana"
|
||||
agent = client.agents.create(
|
||||
name=f"test_agent_rethink",
|
||||
include_base_tools=True,
|
||||
model=MODELS["sonnet"],
|
||||
tags=["test_agents"],
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[rethink_tool.id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": target_block_label,
|
||||
"value": "Name: Matt",
|
||||
},
|
||||
],
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
agent = await server.create_agent_async(
|
||||
request=CreateAgent(
|
||||
name=f"test_agent_rethink",
|
||||
include_base_tools=True,
|
||||
model=MODELS["sonnet"],
|
||||
tags=["test_agents"],
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[rethink_tool.id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": target_block_label,
|
||||
"value": "Name: Matt",
|
||||
},
|
||||
],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
agents = [agent]
|
||||
batch_requests = [
|
||||
@@ -444,13 +440,13 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# Check that the tool has been executed correctly
|
||||
agent = client.agents.retrieve(agent_id=agent.id)
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor)
|
||||
for block in agent.memory.blocks:
|
||||
if block.label == target_block_label:
|
||||
assert block.value == new_memory
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_partial_error_from_anthropic_batch(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
@@ -610,7 +606,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_resume_step_some_stop(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
@@ -773,7 +769,7 @@ def _assert_descending_order(messages):
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_resume_step_after_request_all_continue(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
@@ -911,7 +907,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user