From c31efe1517a81ec64a74582e2d4ea2cd2ada389e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 24 Sep 2025 16:08:47 -0700 Subject: [PATCH] fix: Fix test letta agent batch (#4918) Test letta agent batch --- tests/test_letta_agent_batch.py | 43 +++++++++++++++++---------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index afdb2f79..b7845af8 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -50,7 +50,7 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"] @pytest.fixture -def weather_tool(server): +async def weather_tool(server): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. @@ -75,14 +75,14 @@ def weather_tool(server): else: raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - actor = server.user_manager.get_user_or_default() - tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor) + actor = await server.user_manager.get_actor_or_default_async() + tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=get_weather), actor=actor) # Yield the created tool yield tool @pytest.fixture -def rethink_tool(server): +async def rethink_tool(server): def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore """ Re-evaluate the memory in block_name, integrating new and updated facts. @@ -98,8 +98,8 @@ def rethink_tool(server): agent_state.memory.update_block_value(label=target_block_label, value=new_memory) return None - actor = server.user_manager.get_user_or_default() - tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor) + actor = await server.user_manager.get_actor_or_default_async() + tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=rethink_memory), actor=actor) # Yield the created tool yield tool @@ -112,7 +112,7 @@ async def agents(server, weather_tool): Returns: Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models """ - actor = server.user_manager.get_user_or_default() + actor = await server.user_manager.get_actor_or_default_async() async def create_agent(suffix, model_name): return await server.create_agent_async( @@ -293,7 +293,7 @@ def dummy_batch_response(): @pytest.fixture -def server(): +async def server(): """ Creates a SyncServer instance for testing. @@ -304,25 +304,26 @@ def server(): config.save() server = SyncServer(init_with_default_org_and_user=True) + await server.init_async(init_with_default_org_and_user=True) yield server @pytest.fixture -def default_organization(server): +async def default_organization(server): """Fixture to create and return the default organization.""" - org = server.organization_manager.create_default_organization() + org = await server.organization_manager.get_default_organization_async() yield org @pytest.fixture -def default_user(server, default_organization): +async def default_user(server, default_organization): """Fixture to create and return the default user within the default organization.""" - user = server.user_manager.create_default_user(org_id=default_organization.id) + user = await server.user_manager.get_default_actor_async() yield user @pytest.fixture -def batch_job(default_user, server): +async def batch_job(default_user, server): job = BatchJob( user_id=default_user.id, status=JobStatus.created, @@ -330,11 +331,11 @@ def batch_job(default_user, server): "job_type": "batch_messages", }, ) - job = server.job_manager.create_job(pydantic_job=job, actor=default_user) + job = await server.job_manager.create_job_async(pydantic_job=job, actor=default_user) yield job # cleanup - server.job_manager.delete_job_by_id(job.id, actor=default_user) + await server.job_manager.delete_job_by_id_async(job.id, actor=default_user) class MockAsyncIterable: @@ -359,7 +360,7 @@ class MockAsyncIterable: async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool): target_block_label = "human" new_memory = "banana" - actor = server.user_manager.get_user_or_default() + actor = await server.user_manager.get_actor_or_default_async() agent = await server.create_agent_async( request=CreateAgent( name="test_agent_rethink", @@ -435,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, defa await poll_running_llm_batches(server) # Check that the tool has been executed correctly - agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor) + agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=actor) for block in agent.memory.blocks: if block.label == target_block_label: assert block.value == new_memory @@ -577,7 +578,7 @@ async def test_partial_error_from_anthropic_batch( # Check that agent states have been properly modified to have extended in-context messages for agent in agents: - refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user) + refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user) if refreshed_agent.id == agents_failed[0].id: assert len(refreshed_agent.message_ids) == 4, ( f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}" @@ -734,7 +735,7 @@ async def test_resume_step_some_stop( # Check that agent states have been properly modified to have extended in-context messages for agent in agents: - refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user) + refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user) assert len(refreshed_agent.message_ids) == 6, ( f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" ) @@ -894,7 +895,7 @@ async def test_resume_step_after_request_all_continue( # Check that agent states have been properly modified to have extended in-context messages for agent in agents: - refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user) + refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user) assert len(refreshed_agent.message_ids) == 6, ( f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" ) @@ -940,7 +941,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( # Set up spy function for the Anthropic client with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async") as mock_send: # Configure mock to validate input and return dummy response - async def validate_batch_request(*, agent_messages_mapping, agent_tools_mapping, agent_llm_config_mapping): + async def validate_batch_request(*, agent_type, agent_messages_mapping, agent_tools_mapping, agent_llm_config_mapping): # Verify all agent IDs are present in all mappings expected_ids = sorted(expected_models.keys()) actual_ids = sorted(agent_messages_mapping.keys())