committed by
Caren Thomas
parent
d78e0ccb58
commit
c31efe1517
@@ -50,7 +50,7 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool(server):
|
||||
async def weather_tool(server):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
@@ -75,14 +75,14 @@ def weather_tool(server):
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async()
|
||||
tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=get_weather), actor=actor)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rethink_tool(server):
|
||||
async def rethink_tool(server):
|
||||
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
|
||||
"""
|
||||
Re-evaluate the memory in block_name, integrating new and updated facts.
|
||||
@@ -98,8 +98,8 @@ def rethink_tool(server):
|
||||
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
|
||||
return None
|
||||
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async()
|
||||
tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=rethink_memory), actor=actor)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
@@ -112,7 +112,7 @@ async def agents(server, weather_tool):
|
||||
Returns:
|
||||
Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
actor = await server.user_manager.get_actor_or_default_async()
|
||||
|
||||
async def create_agent(suffix, model_name):
|
||||
return await server.create_agent_async(
|
||||
@@ -293,7 +293,7 @@ def dummy_batch_response():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
async def server():
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
|
||||
@@ -304,25 +304,26 @@ def server():
|
||||
config.save()
|
||||
|
||||
server = SyncServer(init_with_default_org_and_user=True)
|
||||
await server.init_async(init_with_default_org_and_user=True)
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization(server):
|
||||
async def default_organization(server):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_default_organization()
|
||||
org = await server.organization_manager.get_default_organization_async()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server, default_organization):
|
||||
async def default_user(server, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_default_user(org_id=default_organization.id)
|
||||
user = await server.user_manager.get_default_actor_async()
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_job(default_user, server):
|
||||
async def batch_job(default_user, server):
|
||||
job = BatchJob(
|
||||
user_id=default_user.id,
|
||||
status=JobStatus.created,
|
||||
@@ -330,11 +331,11 @@ def batch_job(default_user, server):
|
||||
"job_type": "batch_messages",
|
||||
},
|
||||
)
|
||||
job = server.job_manager.create_job(pydantic_job=job, actor=default_user)
|
||||
job = await server.job_manager.create_job_async(pydantic_job=job, actor=default_user)
|
||||
yield job
|
||||
|
||||
# cleanup
|
||||
server.job_manager.delete_job_by_id(job.id, actor=default_user)
|
||||
await server.job_manager.delete_job_by_id_async(job.id, actor=default_user)
|
||||
|
||||
|
||||
class MockAsyncIterable:
|
||||
@@ -359,7 +360,7 @@ class MockAsyncIterable:
|
||||
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
|
||||
target_block_label = "human"
|
||||
new_memory = "banana"
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
actor = await server.user_manager.get_actor_or_default_async()
|
||||
agent = await server.create_agent_async(
|
||||
request=CreateAgent(
|
||||
name="test_agent_rethink",
|
||||
@@ -435,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, defa
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# Check that the tool has been executed correctly
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=actor)
|
||||
for block in agent.memory.blocks:
|
||||
if block.label == target_block_label:
|
||||
assert block.value == new_memory
|
||||
@@ -577,7 +578,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
|
||||
# Check that agent states have been properly modified to have extended in-context messages
|
||||
for agent in agents:
|
||||
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
|
||||
refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
|
||||
if refreshed_agent.id == agents_failed[0].id:
|
||||
assert len(refreshed_agent.message_ids) == 4, (
|
||||
f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
@@ -734,7 +735,7 @@ async def test_resume_step_some_stop(
|
||||
|
||||
# Check that agent states have been properly modified to have extended in-context messages
|
||||
for agent in agents:
|
||||
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
|
||||
refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
|
||||
assert len(refreshed_agent.message_ids) == 6, (
|
||||
f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
)
|
||||
@@ -894,7 +895,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
|
||||
# Check that agent states have been properly modified to have extended in-context messages
|
||||
for agent in agents:
|
||||
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
|
||||
refreshed_agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
|
||||
assert len(refreshed_agent.message_ids) == 6, (
|
||||
f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
)
|
||||
@@ -940,7 +941,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
# Set up spy function for the Anthropic client
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async") as mock_send:
|
||||
# Configure mock to validate input and return dummy response
|
||||
async def validate_batch_request(*, agent_messages_mapping, agent_tools_mapping, agent_llm_config_mapping):
|
||||
async def validate_batch_request(*, agent_type, agent_messages_mapping, agent_tools_mapping, agent_llm_config_mapping):
|
||||
# Verify all agent IDs are present in all mappings
|
||||
expected_ids = sorted(expected_models.keys())
|
||||
actual_ids = sorted(agent_messages_mapping.keys())
|
||||
|
||||
Reference in New Issue
Block a user