fix: Fix test letta agent batch (#4918)

Test letta agent batch
This commit is contained in:
Matthew Zhou
2025-09-24 16:08:47 -07:00
committed by Caren Thomas
parent d78e0ccb58
commit c31efe1517

View File

@@ -50,7 +50,7 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
@pytest.fixture
def weather_tool(server):
async def weather_tool(server):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
@@ -75,14 +75,14 @@ def weather_tool(server):
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
actor = await server.user_manager.get_actor_or_default_async()
tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=get_weather), actor=actor)
# Yield the created tool
yield tool
@pytest.fixture
def rethink_tool(server):
async def rethink_tool(server):
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
"""
Re-evaluate the memory in block_name, integrating new and updated facts.
@@ -98,8 +98,8 @@ def rethink_tool(server):
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
return None
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor)
actor = await server.user_manager.get_actor_or_default_async()
tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=rethink_memory), actor=actor)
# Yield the created tool
yield tool
@@ -112,7 +112,7 @@ async def agents(server, weather_tool):
Returns:
Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models
"""
actor = server.user_manager.get_user_or_default()
actor = await server.user_manager.get_actor_or_default_async()
async def create_agent(suffix, model_name):
return await server.create_agent_async(
@@ -293,7 +293,7 @@ def dummy_batch_response():
@pytest.fixture
def server():
async def server():
"""
Creates a SyncServer instance for testing.
@@ -304,25 +304,26 @@ def server():
config.save()
server = SyncServer(init_with_default_org_and_user=True)
await server.init_async(init_with_default_org_and_user=True)
yield server
@pytest.fixture
def default_organization(server):
async def default_organization(server):
"""Fixture to create and return the default organization."""
org = server.organization_manager.create_default_organization()
org = await server.organization_manager.get_default_organization_async()
yield org
@pytest.fixture
def default_user(server, default_organization):
async def default_user(server, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_default_user(org_id=default_organization.id)
user = await server.user_manager.get_default_actor_async()
yield user
@pytest.fixture
def batch_job(default_user, server):
async def batch_job(default_user, server):
job = BatchJob(
user_id=default_user.id,
status=JobStatus.created,
@@ -330,11 +331,11 @@ def batch_job(default_user, server):
"job_type": "batch_messages",
},
)
job = server.job_manager.create_job(pydantic_job=job, actor=default_user)
job = await server.job_manager.create_job_async(pydantic_job=job, actor=default_user)
yield job
# cleanup
server.job_manager.delete_job_by_id(job.id, actor=default_user)
await server.job_manager.delete_job_by_id_async(job.id, actor=default_user)
class MockAsyncIterable:
@@ -359,7 +360,7 @@ class MockAsyncIterable:
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
target_block_label = "human"
new_memory = "banana"
actor = server.user_manager.get_user_or_default()
actor = await server.user_manager.get_actor_or_default_async()
agent = await server.create_agent_async(
request=CreateAgent(
name="test_agent_rethink",
@@ -435,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, defa
await poll_running_llm_batches(server)
# Check that the tool has been executed correctly
agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor)
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=actor)
for block in agent.memory.blocks:
if block.label == target_block_label:
assert block.value == new_memory
@@ -577,7 +578,7 @@ async def test_partial_error_from_anthropic_batch(
# Check that agent states have been properly modified to have extended in-context messages
for agent in agents:
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
if refreshed_agent.id == agents_failed[0].id:
assert len(refreshed_agent.message_ids) == 4, (
f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
@@ -734,7 +735,7 @@ async def test_resume_step_some_stop(
# Check that agent states have been properly modified to have extended in-context messages
for agent in agents:
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
assert len(refreshed_agent.message_ids) == 6, (
f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
)
@@ -894,7 +895,7 @@ async def test_resume_step_after_request_all_continue(
# Check that agent states have been properly modified to have extended in-context messages
for agent in agents:
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
assert len(refreshed_agent.message_ids) == 6, (
f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
)
@@ -940,7 +941,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
# Set up spy function for the Anthropic client
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async") as mock_send:
# Configure mock to validate input and return dummy response
async def validate_batch_request(*, agent_messages_mapping, agent_tools_mapping, agent_llm_config_mapping):
async def validate_batch_request(*, agent_type, agent_messages_mapping, agent_tools_mapping, agent_llm_config_mapping):
# Verify all agent IDs are present in all mappings
expected_ids = sorted(expected_models.keys())
actual_ids = sorted(agent_messages_mapping.keys())