test: update test_managers with async functions (#2243)
This commit is contained in:
@@ -185,6 +185,16 @@ class ToolManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
|
||||
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
tool = await ToolModel.read_async(db_session=session, name=tool_name, actor=actor)
|
||||
return tool.id
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
@@ -280,7 +290,8 @@ class ToolManager:
|
||||
tool.tool_type = updated_tool_type
|
||||
|
||||
# Save the updated tool to the database
|
||||
return await tool.update_async(db_session=session, actor=actor).to_pydantic()
|
||||
tool = await tool.update_async(db_session=session, actor=actor)
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -69,7 +69,7 @@ from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_context
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
@@ -92,14 +92,14 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_tables():
|
||||
with db_context() as session:
|
||||
async def _clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
# If this is the block_history table, skip it
|
||||
if table.name == "block_history":
|
||||
continue
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
await session.execute(table.delete()) # Truncate table
|
||||
await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -171,7 +171,7 @@ def default_file(server: SyncServer, default_source, default_user, default_organ
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def print_tool(server: SyncServer, default_user, default_organization):
|
||||
async def print_tool(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
|
||||
def print_tool(message: str):
|
||||
@@ -199,7 +199,7 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = server.tool_manager.create_tool(tool, actor=default_user)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
@@ -237,24 +237,24 @@ def mcp_tool(server, default_user):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_job(server: SyncServer, default_user):
|
||||
async def default_job(server: SyncServer, default_user):
|
||||
"""Fixture to create and return a default job."""
|
||||
job_pydantic = PydanticJob(
|
||||
user_id=default_user.id,
|
||||
status=JobStatus.pending,
|
||||
)
|
||||
job = server.job_manager.create_job(pydantic_job=job_pydantic, actor=default_user)
|
||||
job = await server.job_manager.create_job_async(pydantic_job=job_pydantic, actor=default_user)
|
||||
yield job
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_run(server: SyncServer, default_user):
|
||||
async def default_run(server: SyncServer, default_user):
|
||||
"""Fixture to create and return a default job."""
|
||||
run_pydantic = PydanticRun(
|
||||
user_id=default_user.id,
|
||||
status=JobStatus.pending,
|
||||
)
|
||||
run = server.job_manager.create_job(pydantic_job=run_pydantic, actor=default_user)
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user)
|
||||
yield run
|
||||
|
||||
|
||||
@@ -403,7 +403,7 @@ def other_block(server: SyncServer, default_user):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_tool(server: SyncServer, default_user, default_organization):
|
||||
async def other_tool(server: SyncServer, default_user, default_organization):
|
||||
def print_other_tool(message: str):
|
||||
"""
|
||||
Args:
|
||||
@@ -428,16 +428,16 @@ def other_tool(server: SyncServer, default_user, default_organization):
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = server.tool_manager.create_tool(tool, actor=default_user)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
async def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_state = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="sarah_agent",
|
||||
memory_blocks=[],
|
||||
@@ -451,9 +451,9 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
async def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_state = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="charles_agent",
|
||||
memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")],
|
||||
@@ -467,7 +467,7 @@ def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block):
|
||||
async def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block):
|
||||
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
|
||||
create_agent_request = CreateAgent(
|
||||
system="test system",
|
||||
@@ -486,7 +486,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
message_buffer_autoclear=True,
|
||||
include_base_tools=False,
|
||||
)
|
||||
created_agent = server.agent_manager.create_agent(
|
||||
created_agent = await server.agent_manager.create_agent_async(
|
||||
create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -550,9 +550,9 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_with_tags(server: SyncServer, default_user):
|
||||
async def agent_with_tags(server: SyncServer, default_user):
|
||||
"""Fixture to create agents with specific tags."""
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent1",
|
||||
tags=["primary_agent", "benefit_1"],
|
||||
@@ -564,7 +564,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent2",
|
||||
tags=["primary_agent", "benefit_2"],
|
||||
@@ -576,7 +576,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent3 = server.agent_manager.create_agent(
|
||||
agent3 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent3",
|
||||
tags=["primary_agent", "benefit_1", "benefit_2"],
|
||||
@@ -656,17 +656,18 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen
|
||||
comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user)
|
||||
|
||||
# Test list agent
|
||||
list_agents = server.agent_manager.list_agents(actor=default_user)
|
||||
list_agents = await server.agent_manager.list_agents_async(actor=default_user)
|
||||
assert len(list_agents) == 1
|
||||
comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user)
|
||||
|
||||
# Test deleting the agent
|
||||
server.agent_manager.delete_agent(get_agent.id, default_user)
|
||||
list_agents = server.agent_manager.list_agents(actor=default_user)
|
||||
list_agents = await server.agent_manager.list_agents_async(actor=default_user)
|
||||
assert len(list_agents) == 0
|
||||
|
||||
|
||||
def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block, event_loop):
|
||||
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
|
||||
create_agent_request = CreateAgent(
|
||||
system="test system",
|
||||
@@ -679,12 +680,12 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_state = await server.agent_manager.create_agent_async(
|
||||
create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 2
|
||||
init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
|
||||
assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 2
|
||||
init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user)
|
||||
|
||||
# Check that the system appears in the first initial message
|
||||
assert create_agent_request.system in init_messages[0].content[0].text
|
||||
@@ -694,7 +695,8 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
||||
assert create_agent_request.initial_message_sequence[0].content in init_messages[1].content[0].text
|
||||
|
||||
|
||||
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block, event_loop):
|
||||
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
|
||||
create_agent_request = CreateAgent(
|
||||
system="test system",
|
||||
@@ -706,18 +708,19 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user,
|
||||
description="test_description",
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_state = await server.agent_manager.create_agent_async(
|
||||
create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 4
|
||||
init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
|
||||
assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 4
|
||||
init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user)
|
||||
# Check that the system appears in the first initial message
|
||||
assert create_agent_request.system in init_messages[0].content[0].text
|
||||
assert create_agent_request.memory_blocks[0].value in init_messages[0].content[0].text
|
||||
|
||||
|
||||
def test_create_agent_with_json_in_system_message(server: SyncServer, default_user, default_block):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_with_json_in_system_message(server: SyncServer, default_user, default_block, event_loop):
|
||||
system_prompt = (
|
||||
"You are an expert teaching agent with encyclopedic knowledge. "
|
||||
"When you receive a topic, query the external database for more "
|
||||
@@ -734,19 +737,22 @@ def test_create_agent_with_json_in_system_message(server: SyncServer, default_us
|
||||
description="test_description",
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_state = await server.agent_manager.create_agent_async(
|
||||
create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
assert agent_state is not None
|
||||
system_message_id = agent_state.message_ids[0]
|
||||
system_message = server.message_manager.get_message_by_id(message_id=system_message_id, actor=default_user)
|
||||
system_message = await server.message_manager.get_message_by_id_async(message_id=system_message_id, actor=default_user)
|
||||
assert system_prompt in system_message.content[0].text
|
||||
assert default_block.value in system_message.content[0].text
|
||||
server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user)
|
||||
|
||||
|
||||
def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent(
|
||||
server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user, event_loop
|
||||
):
|
||||
agent, _ = comprehensive_test_agent_fixture
|
||||
update_agent_request = UpdateAgent(
|
||||
name="train_agent",
|
||||
@@ -766,7 +772,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
)
|
||||
|
||||
last_updated_timestamp = agent.updated_at
|
||||
updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user)
|
||||
updated_agent = await server.agent_manager.update_agent_async(agent.id, update_agent_request, actor=default_user)
|
||||
comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user)
|
||||
assert updated_agent.message_ids == update_agent_request.message_ids
|
||||
assert updated_agent.updated_at > last_updated_timestamp
|
||||
@@ -777,12 +783,13 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
|
||||
# Create an agent using the comprehensive fixture.
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# List agents using an empty list for select_fields.
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=[])
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=[])
|
||||
# Assert that the agent is returned and basic fields are present.
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
@@ -794,12 +801,13 @@ def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_
|
||||
assert len(agent.tags) == 0
|
||||
|
||||
|
||||
def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
|
||||
# Create an agent using the comprehensive fixture.
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# List agents using an empty list for select_fields.
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=None)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=None)
|
||||
# Assert that the agent is returned and basic fields are present.
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
@@ -811,12 +819,13 @@ def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_a
|
||||
assert len(agent.tags) > 0
|
||||
|
||||
|
||||
def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Choose a subset of valid relationship fields.
|
||||
valid_fields = ["tools", "tags"]
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=valid_fields)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=valid_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Depending on your to_pydantic() implementation,
|
||||
@@ -827,13 +836,14 @@ def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_te
|
||||
assert not agent.memory.blocks
|
||||
|
||||
|
||||
def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Provide field names that are not recognized.
|
||||
invalid_fields = ["foobar", "nonexistent_field"]
|
||||
# The expectation is that these fields are simply ignored.
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=invalid_fields)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=invalid_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Verify that standard fields are still present.c
|
||||
@@ -841,12 +851,13 @@ def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_tes
|
||||
assert agent.name is not None
|
||||
|
||||
|
||||
def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Provide duplicate valid field names.
|
||||
duplicate_fields = ["tools", "tools", "tags", "tags"]
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=duplicate_fields)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=duplicate_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Verify that the agent pydantic representation includes the relationships.
|
||||
@@ -855,12 +866,13 @@ def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_
|
||||
assert isinstance(agent.tags, list)
|
||||
|
||||
|
||||
def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Mix valid fields with an invalid one.
|
||||
mixed_fields = ["tools", "invalid_field"]
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=mixed_fields)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=mixed_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Valid fields should be loaded and accessible.
|
||||
@@ -870,9 +882,10 @@ def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_
|
||||
assert not hasattr(agent, "invalid_field")
|
||||
|
||||
|
||||
def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_ascending(server: SyncServer, default_user, event_loop):
|
||||
# Create two agents with known names
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_oldest",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
@@ -886,7 +899,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_newest",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
@@ -897,14 +910,15 @@ def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agents = server.agent_manager.list_agents(actor=default_user, ascending=True)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, ascending=True)
|
||||
names = [agent.name for agent in agents]
|
||||
assert names.index("agent_oldest") < names.index("agent_newest")
|
||||
|
||||
|
||||
def test_list_agents_descending(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_descending(server: SyncServer, default_user, event_loop):
|
||||
# Create two agents with known names
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_oldest",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
@@ -918,7 +932,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent_newest",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
@@ -929,18 +943,19 @@ def test_list_agents_descending(server: SyncServer, default_user):
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agents = server.agent_manager.list_agents(actor=default_user, ascending=False)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, ascending=False)
|
||||
names = [agent.name for agent in agents]
|
||||
assert names.index("agent_newest") < names.index("agent_oldest")
|
||||
|
||||
|
||||
def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_ordering_and_pagination(server: SyncServer, default_user, event_loop):
|
||||
names = ["alpha_agent", "beta_agent", "gamma_agent"]
|
||||
created_agents = []
|
||||
|
||||
# Create agents in known order
|
||||
for name in names:
|
||||
agent = server.agent_manager.create_agent(
|
||||
agent = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name=name,
|
||||
memory_blocks=[],
|
||||
@@ -957,17 +972,17 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
|
||||
agent_ids = {agent.name: agent.id for agent in created_agents}
|
||||
|
||||
# Ascending (oldest to newest)
|
||||
agents_asc = server.agent_manager.list_agents(actor=default_user, ascending=True)
|
||||
agents_asc = await server.agent_manager.list_agents_async(actor=default_user, ascending=True)
|
||||
asc_names = [agent.name for agent in agents_asc]
|
||||
assert asc_names.index("alpha_agent") < asc_names.index("beta_agent") < asc_names.index("gamma_agent")
|
||||
|
||||
# Descending (newest to oldest)
|
||||
agents_desc = server.agent_manager.list_agents(actor=default_user, ascending=False)
|
||||
agents_desc = await server.agent_manager.list_agents_async(actor=default_user, ascending=False)
|
||||
desc_names = [agent.name for agent in agents_desc]
|
||||
assert desc_names.index("gamma_agent") < desc_names.index("beta_agent") < desc_names.index("alpha_agent")
|
||||
|
||||
# After: Get agents after alpha_agent in ascending order (should exclude alpha)
|
||||
after_alpha = server.agent_manager.list_agents(actor=default_user, after=agent_ids["alpha_agent"], ascending=True)
|
||||
after_alpha = await server.agent_manager.list_agents_async(actor=default_user, after=agent_ids["alpha_agent"], ascending=True)
|
||||
after_names = [a.name for a in after_alpha]
|
||||
assert "alpha_agent" not in after_names
|
||||
assert "beta_agent" in after_names
|
||||
@@ -975,7 +990,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
|
||||
assert after_names == ["beta_agent", "gamma_agent"]
|
||||
|
||||
# Before: Get agents before gamma_agent in ascending order (should exclude gamma)
|
||||
before_gamma = server.agent_manager.list_agents(actor=default_user, before=agent_ids["gamma_agent"], ascending=True)
|
||||
before_gamma = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["gamma_agent"], ascending=True)
|
||||
before_names = [a.name for a in before_gamma]
|
||||
assert "gamma_agent" not in before_names
|
||||
assert "alpha_agent" in before_names
|
||||
@@ -983,12 +998,12 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
|
||||
assert before_names == ["alpha_agent", "beta_agent"]
|
||||
|
||||
# After: Get agents after gamma_agent in descending order (should exclude gamma, return beta then alpha)
|
||||
after_gamma_desc = server.agent_manager.list_agents(actor=default_user, after=agent_ids["gamma_agent"], ascending=False)
|
||||
after_gamma_desc = await server.agent_manager.list_agents_async(actor=default_user, after=agent_ids["gamma_agent"], ascending=False)
|
||||
after_names_desc = [a.name for a in after_gamma_desc]
|
||||
assert after_names_desc == ["beta_agent", "alpha_agent"]
|
||||
|
||||
# Before: Get agents before alpha_agent in descending order (should exclude alpha)
|
||||
before_alpha_desc = server.agent_manager.list_agents(actor=default_user, before=agent_ids["alpha_agent"], ascending=False)
|
||||
before_alpha_desc = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["alpha_agent"], ascending=False)
|
||||
before_names_desc = [a.name for a in before_alpha_desc]
|
||||
assert before_names_desc == ["gamma_agent", "beta_agent"]
|
||||
|
||||
@@ -1239,74 +1254,85 @@ def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_wi
|
||||
assert len(agents) == 0 # No agent should match
|
||||
|
||||
|
||||
def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop):
|
||||
"""Test listing agents that have ALL specified tags."""
|
||||
# Create agents with multiple tags
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user)
|
||||
server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user)
|
||||
|
||||
# Search for agents with all specified tags
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["test", "gpt4"], match_all_tags=True)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["test", "gpt4"], match_all_tags=True)
|
||||
assert len(agents) == 2
|
||||
agent_ids = [a.id for a in agents]
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
|
||||
# Search for tags that only sarah_agent has
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["test", "production"], match_all_tags=True)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["test", "production"], match_all_tags=True)
|
||||
assert len(agents) == 1
|
||||
assert agents[0].id == sarah_agent.id
|
||||
|
||||
|
||||
def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop):
|
||||
"""Test listing agents that have ANY of the specified tags."""
|
||||
# Create agents with different tags
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user)
|
||||
server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user)
|
||||
|
||||
# Search for agents with any of the specified tags
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["production", "development"], match_all_tags=False)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["production", "development"], match_all_tags=False)
|
||||
assert len(agents) == 2
|
||||
agent_ids = [a.id for a in agents]
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
|
||||
# Search for tags where only sarah_agent matches
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["production", "nonexistent"], match_all_tags=False)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["production", "nonexistent"], match_all_tags=False)
|
||||
assert len(agents) == 1
|
||||
assert agents[0].id == sarah_agent.id
|
||||
|
||||
|
||||
def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop):
|
||||
"""Test listing agents when no tags match."""
|
||||
# Create agents with tags
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user)
|
||||
server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user)
|
||||
|
||||
# Search for nonexistent tags
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=True)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=True)
|
||||
assert len(agents) == 0
|
||||
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=False)
|
||||
agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=False)
|
||||
assert len(agents) == 0
|
||||
|
||||
|
||||
def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop):
|
||||
"""Test combining tag search with other filters."""
|
||||
# Create agents with specific names and tags
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user)
|
||||
server.agent_manager.update_agent(charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(
|
||||
sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user
|
||||
)
|
||||
await server.agent_manager.update_agent_async(
|
||||
charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user
|
||||
)
|
||||
|
||||
# List agents with specific tag and name pattern
|
||||
agents = server.agent_manager.list_agents(actor=default_user, tags=["production"], match_all_tags=True, name="production_agent")
|
||||
agents = await server.agent_manager.list_agents_async(
|
||||
actor=default_user, tags=["production"], match_all_tags=True, name="production_agent"
|
||||
)
|
||||
assert len(agents) == 1
|
||||
assert agents[0].id == sarah_agent.id
|
||||
|
||||
|
||||
def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization, event_loop):
|
||||
"""Test pagination when listing agents by tags."""
|
||||
# Create first agent
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent1",
|
||||
tags=["pagination_test", "tag1"],
|
||||
@@ -1322,7 +1348,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps
|
||||
|
||||
# Create second agent
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent2",
|
||||
tags=["pagination_test", "tag2"],
|
||||
@@ -1335,19 +1361,19 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = server.agent_manager.list_agents(actor=default_user, tags=["pagination_test"], match_all_tags=True, limit=1)
|
||||
first_page = await server.agent_manager.list_agents_async(actor=default_user, tags=["pagination_test"], match_all_tags=True, limit=1)
|
||||
assert len(first_page) == 1
|
||||
first_agent_id = first_page[0].id
|
||||
|
||||
# Get second page using cursor
|
||||
second_page = server.agent_manager.list_agents(
|
||||
second_page = await server.agent_manager.list_agents_async(
|
||||
actor=default_user, tags=["pagination_test"], match_all_tags=True, after=first_agent_id, limit=1
|
||||
)
|
||||
assert len(second_page) == 1
|
||||
assert second_page[0].id != first_agent_id
|
||||
|
||||
# Get previous page using before
|
||||
prev_page = server.agent_manager.list_agents(
|
||||
prev_page = await server.agent_manager.list_agents_async(
|
||||
actor=default_user, tags=["pagination_test"], match_all_tags=True, before=second_page[0].id, limit=1
|
||||
)
|
||||
assert len(prev_page) == 1
|
||||
@@ -1360,10 +1386,11 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
assert agent2.id in all_ids
|
||||
|
||||
|
||||
def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization, event_loop):
|
||||
"""Test listing agents with query text filtering and pagination."""
|
||||
# Create test agents with specific names and descriptions
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent1 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="Search Agent One",
|
||||
memory_blocks=[],
|
||||
@@ -1375,7 +1402,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="Search Agent Two",
|
||||
memory_blocks=[],
|
||||
@@ -1387,7 +1414,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent3 = server.agent_manager.create_agent(
|
||||
agent3 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="Different Agent",
|
||||
memory_blocks=[],
|
||||
@@ -1400,32 +1427,32 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
)
|
||||
|
||||
# Test query text filtering
|
||||
search_results = server.agent_manager.list_agents(actor=default_user, query_text="search agent")
|
||||
search_results = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent")
|
||||
assert len(search_results) == 2
|
||||
search_agent_ids = {agent.id for agent in search_results}
|
||||
assert agent1.id in search_agent_ids
|
||||
assert agent2.id in search_agent_ids
|
||||
assert agent3.id not in search_agent_ids
|
||||
|
||||
different_results = server.agent_manager.list_agents(actor=default_user, query_text="different agent")
|
||||
different_results = await server.agent_manager.list_agents_async(actor=default_user, query_text="different agent")
|
||||
assert len(different_results) == 1
|
||||
assert different_results[0].id == agent3.id
|
||||
|
||||
# Test pagination with query text
|
||||
first_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", limit=1)
|
||||
first_page = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent", limit=1)
|
||||
assert len(first_page) == 1
|
||||
first_agent_id = first_page[0].id
|
||||
|
||||
# Get second page using cursor
|
||||
second_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", after=first_agent_id, limit=1)
|
||||
second_page = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent", after=first_agent_id, limit=1)
|
||||
assert len(second_page) == 1
|
||||
assert second_page[0].id != first_agent_id
|
||||
|
||||
# Test before and after
|
||||
all_agents = server.agent_manager.list_agents(actor=default_user, query_text="agent")
|
||||
all_agents = await server.agent_manager.list_agents_async(actor=default_user, query_text="agent")
|
||||
assert len(all_agents) == 3
|
||||
first_agent, second_agent, third_agent = all_agents
|
||||
middle_agent = server.agent_manager.list_agents(
|
||||
middle_agent = await server.agent_manager.list_agents_async(
|
||||
actor=default_user, query_text="search agent", before=third_agent.id, after=first_agent.id
|
||||
)
|
||||
assert len(middle_agent) == 1
|
||||
@@ -1449,7 +1476,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau
|
||||
does not fail and clears out message_ids if somehow it's non-empty.
|
||||
"""
|
||||
# Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages).
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user)
|
||||
updated_agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user)
|
||||
assert updated_agent.message_ids == ["ghost-message-id"]
|
||||
|
||||
@@ -1457,7 +1484,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
# Double check that physically no messages exist
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1467,7 +1494,7 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent,
|
||||
does not fail and clears out message_ids if somehow it's non-empty.
|
||||
"""
|
||||
# Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages).
|
||||
server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user)
|
||||
await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user)
|
||||
updated_agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user)
|
||||
assert updated_agent.message_ids == ["ghost-message-id"]
|
||||
|
||||
@@ -1475,7 +1502,7 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent,
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True)
|
||||
assert len(reset_agent.message_ids) == 4
|
||||
# Double check that physically no messages exist
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 4
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1508,7 +1535,7 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a
|
||||
agent_before = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user)
|
||||
# This is 4 because creating the message does not necessarily add it to the in context message ids
|
||||
assert len(agent_before.message_ids) == 4
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 6
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 6
|
||||
|
||||
# 2. Reset all messages
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
@@ -1517,10 +1544,11 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
|
||||
# 4. Verify the messages are physically removed
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
|
||||
def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""
|
||||
Test that calling reset_messages multiple times has no adverse effect.
|
||||
"""
|
||||
@@ -1537,15 +1565,16 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
|
||||
# First reset
|
||||
reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
# Second reset should do nothing new
|
||||
reset_agent_again = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(reset_agent.message_ids) == 1
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
|
||||
def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_letta_message(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""
|
||||
Test updating a message.
|
||||
"""
|
||||
@@ -1560,32 +1589,32 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
|
||||
|
||||
# user message
|
||||
update_user_message = UpdateUserMessage(content="Hello, Sarah!")
|
||||
original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
|
||||
original_user_message = await server.message_manager.get_message_by_id_async(message_id=user_message.id, actor=default_user)
|
||||
assert original_user_message.content[0].text != update_user_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=user_message.id, letta_message_update=update_user_message, actor=default_user
|
||||
)
|
||||
updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
|
||||
updated_user_message = await server.message_manager.get_message_by_id_async(message_id=user_message.id, actor=default_user)
|
||||
assert updated_user_message.content[0].text == update_user_message.content
|
||||
|
||||
# system message
|
||||
update_system_message = UpdateSystemMessage(content="You are a friendly assistant!")
|
||||
original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
|
||||
original_system_message = await server.message_manager.get_message_by_id_async(message_id=system_message.id, actor=default_user)
|
||||
assert original_system_message.content[0].text != update_system_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=system_message.id, letta_message_update=update_system_message, actor=default_user
|
||||
)
|
||||
updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
|
||||
updated_system_message = await server.message_manager.get_message_by_id_async(message_id=system_message.id, actor=default_user)
|
||||
assert updated_system_message.content[0].text == update_system_message.content
|
||||
|
||||
# reasoning message
|
||||
update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking")
|
||||
original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
|
||||
original_reasoning_message = await server.message_manager.get_message_by_id_async(message_id=reasoning_message.id, actor=default_user)
|
||||
assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user
|
||||
)
|
||||
updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
|
||||
updated_reasoning_message = await server.message_manager.get_message_by_id_async(message_id=reasoning_message.id, actor=default_user)
|
||||
assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning
|
||||
|
||||
# assistant message
|
||||
@@ -1597,14 +1626,14 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
|
||||
return arguments["message"]
|
||||
|
||||
update_assistant_message = UpdateAssistantMessage(content="I am an agent!")
|
||||
original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
|
||||
original_assistant_message = await server.message_manager.get_message_by_id_async(message_id=assistant_message.id, actor=default_user)
|
||||
print("ORIGINAL", original_assistant_message.tool_calls)
|
||||
print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0]))
|
||||
assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user
|
||||
)
|
||||
updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
|
||||
updated_assistant_message = await server.message_manager.get_message_by_id_async(message_id=assistant_message.id, actor=default_user)
|
||||
print("UPDATED", updated_assistant_message.tool_calls)
|
||||
print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0]))
|
||||
assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content
|
||||
@@ -2944,7 +2973,7 @@ def test_checkpoint_creates_history(server: SyncServer, default_user):
|
||||
# Act: checkpoint it
|
||||
block_manager.checkpoint_block(block_id=created_block.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
with db_registry.session() as session:
|
||||
# Get BlockHistory entries for this block
|
||||
history_entries: List[BlockHistory] = session.query(BlockHistory).filter(BlockHistory.block_id == created_block.id).all()
|
||||
assert len(history_entries) == 1, "Exactly one history entry should be created"
|
||||
@@ -2977,7 +3006,7 @@ def test_multiple_checkpoints(server: SyncServer, default_user):
|
||||
# 3) Second checkpoint
|
||||
block_manager.checkpoint_block(block_id=block.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
with db_registry.session() as session:
|
||||
history_entries = (
|
||||
session.query(BlockHistory).filter(BlockHistory.block_id == block.id).order_by(BlockHistory.sequence_number.asc()).all()
|
||||
)
|
||||
@@ -3010,7 +3039,7 @@ def test_checkpoint_with_agent_id(server: SyncServer, default_user, sarah_agent)
|
||||
block_manager.checkpoint_block(block_id=block.id, actor=default_user, agent_id=sarah_agent.id)
|
||||
|
||||
# Verify
|
||||
with db_context() as session:
|
||||
with db_registry.session() as session:
|
||||
hist_entry = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).one()
|
||||
assert hist_entry.actor_type == ActorType.LETTA_AGENT
|
||||
assert hist_entry.actor_id == sarah_agent.id
|
||||
@@ -3031,7 +3060,7 @@ def test_checkpoint_with_no_state_change(server: SyncServer, default_user):
|
||||
# 2) checkpoint again (no changes)
|
||||
block_manager.checkpoint_block(block_id=block.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
with db_registry.session() as session:
|
||||
all_hist = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).all()
|
||||
assert len(all_hist) == 2
|
||||
|
||||
@@ -3043,15 +3072,15 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user):
|
||||
block = block_manager.create_or_update_block(PydanticBlock(label="test_stale_checkpoint", value="hello"), actor=default_user)
|
||||
|
||||
# session1 loads
|
||||
with db_context() as s1:
|
||||
with db_registry.session() as s1:
|
||||
block_s1 = s1.get(Block, block.id) # version=1
|
||||
|
||||
# session2 loads
|
||||
with db_context() as s2:
|
||||
with db_registry.session() as s2:
|
||||
block_s2 = s2.get(Block, block.id) # also version=1
|
||||
|
||||
# session1 checkpoint => version=2
|
||||
with db_context() as s1:
|
||||
with db_registry.session() as s1:
|
||||
block_s1 = s1.merge(block_s1)
|
||||
block_manager.checkpoint_block(
|
||||
block_id=block_s1.id,
|
||||
@@ -3062,7 +3091,7 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user):
|
||||
|
||||
# session2 tries to checkpoint => sees old version=1 => stale error
|
||||
with pytest.raises(StaleDataError):
|
||||
with db_context() as s2:
|
||||
with db_registry.session() as s2:
|
||||
block_s2 = s2.merge(block_s2)
|
||||
block_manager.checkpoint_block(
|
||||
block_id=block_s2.id,
|
||||
@@ -3093,7 +3122,7 @@ def test_checkpoint_no_future_states(server: SyncServer, default_user):
|
||||
# 3) Another checkpoint (no changes made) => should become seq=3, not delete anything
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
with db_registry.session() as session:
|
||||
# We expect 3 rows in block_history, none removed
|
||||
history_rows = (
|
||||
session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all()
|
||||
@@ -3190,7 +3219,7 @@ def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default
|
||||
# 5) Checkpoint => new seq=2, removing the old seq=2 and seq=3
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
with db_registry.session() as session:
|
||||
# Let's see which BlockHistory rows remain
|
||||
history_entries = (
|
||||
session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all()
|
||||
@@ -3306,11 +3335,11 @@ def test_undo_concurrency_stale(server: SyncServer, default_user):
|
||||
# Now block is at seq=2
|
||||
|
||||
# session1 preloads the block
|
||||
with db_context() as s1:
|
||||
with db_registry.session() as s1:
|
||||
block_s1 = s1.get(Block, block_v1.id) # version=? let's say 2 in memory
|
||||
|
||||
# session2 also preloads the block
|
||||
with db_context() as s2:
|
||||
with db_registry.session() as s2:
|
||||
block_s2 = s2.get(Block, block_v1.id) # also version=2
|
||||
|
||||
# Session1 -> undo to seq=1
|
||||
@@ -3474,9 +3503,9 @@ def test_redo_concurrency_stale(server: SyncServer, default_user):
|
||||
# but there's a valid row for seq=3 in block_history (the 'v3' state).
|
||||
|
||||
# 5) Simulate concurrency: two sessions each read the block at seq=2
|
||||
with db_context() as s1:
|
||||
with db_registry.session() as s1:
|
||||
block_s1 = s1.get(Block, block.id)
|
||||
with db_context() as s2:
|
||||
with db_registry.session() as s2:
|
||||
block_s2 = s2.get(Block, block.id)
|
||||
|
||||
# 6) Session1 redoes to seq=3 first -> success
|
||||
|
||||
Reference in New Issue
Block a user