diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index d8e71ba7..a4640098 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -185,6 +185,16 @@ class ToolManager: except NoResultFound: return None + @enforce_types + async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]: + """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" + try: + async with db_registry.async_session() as session: + tool = await ToolModel.read_async(db_session=session, name=tool_name, actor=actor) + return tool.id + except NoResultFound: + return None + @enforce_types async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" @@ -280,7 +290,8 @@ class ToolManager: tool.tool_type = updated_tool_type # Save the updated tool to the database - return await tool.update_async(db_session=session, actor=actor).to_pydantic() + tool = await tool.update_async(db_session=session, actor=actor) + return tool.to_pydantic() @enforce_types def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: diff --git a/tests/test_managers.py b/tests/test_managers.py index 2549fb42..d8aaa256 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -69,7 +69,7 @@ from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.tool_rule import InitToolRule from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate -from letta.server.db import db_context +from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager @@ -92,14 +92,14 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI")) @pytest.fixture(autouse=True) -def _clear_tables(): - with db_context() as session: +async def _clear_tables(): + async with db_registry.async_session() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues # If this is the block_history table, skip it if table.name == "block_history": continue - session.execute(table.delete()) # Truncate table - session.commit() + await session.execute(table.delete()) # Truncate table + await session.commit() @pytest.fixture @@ -171,7 +171,7 @@ def default_file(server: SyncServer, default_source, default_user, default_organ @pytest.fixture -def print_tool(server: SyncServer, default_user, default_organization): +async def print_tool(server: SyncServer, default_user, default_organization): """Fixture to create a tool with default settings and clean up after the test.""" def print_tool(message: str): @@ -199,7 +199,7 @@ def print_tool(server: SyncServer, default_user, default_organization): tool.json_schema = derived_json_schema tool.name = derived_name - tool = server.tool_manager.create_tool(tool, actor=default_user) + tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) # Yield the created tool yield tool @@ -237,24 +237,24 @@ def mcp_tool(server, default_user): @pytest.fixture -def default_job(server: SyncServer, default_user): +async def default_job(server: SyncServer, default_user): """Fixture to create and return a default job.""" job_pydantic = PydanticJob( user_id=default_user.id, status=JobStatus.pending, ) - job = server.job_manager.create_job(pydantic_job=job_pydantic, actor=default_user) + job = await server.job_manager.create_job_async(pydantic_job=job_pydantic, actor=default_user) yield job @pytest.fixture -def default_run(server: SyncServer, default_user): +async def default_run(server: SyncServer, default_user): """Fixture to create and return a default job.""" run_pydantic = PydanticRun( user_id=default_user.id, status=JobStatus.pending, ) - run = server.job_manager.create_job(pydantic_job=run_pydantic, actor=default_user) + run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user) yield run @@ -403,7 +403,7 @@ def other_block(server: SyncServer, default_user): @pytest.fixture -def other_tool(server: SyncServer, default_user, default_organization): +async def other_tool(server: SyncServer, default_user, default_organization): def print_other_tool(message: str): """ Args: @@ -428,16 +428,16 @@ def other_tool(server: SyncServer, default_user, default_organization): tool.json_schema = derived_json_schema tool.name = derived_name - tool = server.tool_manager.create_tool(tool, actor=default_user) + tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) # Yield the created tool yield tool @pytest.fixture -def sarah_agent(server: SyncServer, default_user, default_organization): +async def sarah_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="sarah_agent", memory_blocks=[], @@ -451,9 +451,9 @@ def sarah_agent(server: SyncServer, default_user, default_organization): @pytest.fixture -def charles_agent(server: SyncServer, default_user, default_organization): +async def charles_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="charles_agent", memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], @@ -467,7 +467,7 @@ def charles_agent(server: SyncServer, default_user, default_organization): @pytest.fixture -def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block): +async def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] create_agent_request = CreateAgent( system="test system", @@ -486,7 +486,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too message_buffer_autoclear=True, include_base_tools=False, ) - created_agent = server.agent_manager.create_agent( + created_agent = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) @@ -550,9 +550,9 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent @pytest.fixture -def agent_with_tags(server: SyncServer, default_user): +async def agent_with_tags(server: SyncServer, default_user): """Fixture to create agents with specific tags.""" - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent1", tags=["primary_agent", "benefit_1"], @@ -564,7 +564,7 @@ def agent_with_tags(server: SyncServer, default_user): actor=default_user, ) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent2", tags=["primary_agent", "benefit_2"], @@ -576,7 +576,7 @@ def agent_with_tags(server: SyncServer, default_user): actor=default_user, ) - agent3 = server.agent_manager.create_agent( + agent3 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent3", tags=["primary_agent", "benefit_1", "benefit_2"], @@ -656,17 +656,18 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user) # Test list agent - list_agents = server.agent_manager.list_agents(actor=default_user) + list_agents = await server.agent_manager.list_agents_async(actor=default_user) assert len(list_agents) == 1 comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user) # Test deleting the agent server.agent_manager.delete_agent(get_agent.id, default_user) - list_agents = server.agent_manager.list_agents(actor=default_user) + list_agents = await server.agent_manager.list_agents_async(actor=default_user) assert len(list_agents) == 0 -def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block): +@pytest.mark.asyncio +async def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block, event_loop): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] create_agent_request = CreateAgent( system="test system", @@ -679,12 +680,12 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")], include_base_tools=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) - assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 2 - init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 2 + init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].content[0].text @@ -694,7 +695,8 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use assert create_agent_request.initial_message_sequence[0].content in init_messages[1].content[0].text -def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block): +@pytest.mark.asyncio +async def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block, event_loop): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] create_agent_request = CreateAgent( system="test system", @@ -706,18 +708,19 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user, description="test_description", include_base_tools=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) - assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 4 - init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 4 + init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].content[0].text assert create_agent_request.memory_blocks[0].value in init_messages[0].content[0].text -def test_create_agent_with_json_in_system_message(server: SyncServer, default_user, default_block): +@pytest.mark.asyncio +async def test_create_agent_with_json_in_system_message(server: SyncServer, default_user, default_block, event_loop): system_prompt = ( "You are an expert teaching agent with encyclopedic knowledge. " "When you receive a topic, query the external database for more " @@ -734,19 +737,22 @@ def test_create_agent_with_json_in_system_message(server: SyncServer, default_us description="test_description", include_base_tools=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) assert agent_state is not None system_message_id = agent_state.message_ids[0] - system_message = server.message_manager.get_message_by_id(message_id=system_message_id, actor=default_user) + system_message = await server.message_manager.get_message_by_id_async(message_id=system_message_id, actor=default_user) assert system_prompt in system_message.content[0].text assert default_block.value in system_message.content[0].text server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user) -def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user): +@pytest.mark.asyncio +async def test_update_agent( + server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user, event_loop +): agent, _ = comprehensive_test_agent_fixture update_agent_request = UpdateAgent( name="train_agent", @@ -766,7 +772,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe ) last_updated_timestamp = agent.updated_at - updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user) + updated_agent = await server.agent_manager.update_agent_async(agent.id, update_agent_request, actor=default_user) comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user) assert updated_agent.message_ids == update_agent_request.message_ids assert updated_agent.updated_at > last_updated_timestamp @@ -777,12 +783,13 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe # ====================================================================================================================== -def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): # Create an agent using the comprehensive fixture. created_agent, create_agent_request = comprehensive_test_agent_fixture # List agents using an empty list for select_fields. - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=[]) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=[]) # Assert that the agent is returned and basic fields are present. assert len(agents) >= 1 agent = agents[0] @@ -794,12 +801,13 @@ def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_ assert len(agent.tags) == 0 -def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): # Create an agent using the comprehensive fixture. created_agent, create_agent_request = comprehensive_test_agent_fixture # List agents using an empty list for select_fields. - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=None) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=None) # Assert that the agent is returned and basic fields are present. assert len(agents) >= 1 agent = agents[0] @@ -811,12 +819,13 @@ def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_a assert len(agent.tags) > 0 -def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Choose a subset of valid relationship fields. valid_fields = ["tools", "tags"] - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=valid_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=valid_fields) assert len(agents) >= 1 agent = agents[0] # Depending on your to_pydantic() implementation, @@ -827,13 +836,14 @@ def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_te assert not agent.memory.blocks -def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Provide field names that are not recognized. invalid_fields = ["foobar", "nonexistent_field"] # The expectation is that these fields are simply ignored. - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=invalid_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=invalid_fields) assert len(agents) >= 1 agent = agents[0] # Verify that standard fields are still present.c @@ -841,12 +851,13 @@ def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_tes assert agent.name is not None -def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Provide duplicate valid field names. duplicate_fields = ["tools", "tools", "tags", "tags"] - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=duplicate_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=duplicate_fields) assert len(agents) >= 1 agent = agents[0] # Verify that the agent pydantic representation includes the relationships. @@ -855,12 +866,13 @@ def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_ assert isinstance(agent.tags, list) -def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Mix valid fields with an invalid one. mixed_fields = ["tools", "invalid_field"] - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=mixed_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=mixed_fields) assert len(agents) >= 1 agent = agents[0] # Valid fields should be loaded and accessible. @@ -870,9 +882,10 @@ def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_ assert not hasattr(agent, "invalid_field") -def test_list_agents_ascending(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_agents_ascending(server: SyncServer, default_user, event_loop): # Create two agents with known names - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_oldest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -886,7 +899,7 @@ def test_list_agents_ascending(server: SyncServer, default_user): if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_newest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -897,14 +910,15 @@ def test_list_agents_ascending(server: SyncServer, default_user): actor=default_user, ) - agents = server.agent_manager.list_agents(actor=default_user, ascending=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, ascending=True) names = [agent.name for agent in agents] assert names.index("agent_oldest") < names.index("agent_newest") -def test_list_agents_descending(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_agents_descending(server: SyncServer, default_user, event_loop): # Create two agents with known names - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_oldest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -918,7 +932,7 @@ def test_list_agents_descending(server: SyncServer, default_user): if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_newest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -929,18 +943,19 @@ def test_list_agents_descending(server: SyncServer, default_user): actor=default_user, ) - agents = server.agent_manager.list_agents(actor=default_user, ascending=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, ascending=False) names = [agent.name for agent in agents] assert names.index("agent_newest") < names.index("agent_oldest") -def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_agents_ordering_and_pagination(server: SyncServer, default_user, event_loop): names = ["alpha_agent", "beta_agent", "gamma_agent"] created_agents = [] # Create agents in known order for name in names: - agent = server.agent_manager.create_agent( + agent = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name=name, memory_blocks=[], @@ -957,17 +972,17 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): agent_ids = {agent.name: agent.id for agent in created_agents} # Ascending (oldest to newest) - agents_asc = server.agent_manager.list_agents(actor=default_user, ascending=True) + agents_asc = await server.agent_manager.list_agents_async(actor=default_user, ascending=True) asc_names = [agent.name for agent in agents_asc] assert asc_names.index("alpha_agent") < asc_names.index("beta_agent") < asc_names.index("gamma_agent") # Descending (newest to oldest) - agents_desc = server.agent_manager.list_agents(actor=default_user, ascending=False) + agents_desc = await server.agent_manager.list_agents_async(actor=default_user, ascending=False) desc_names = [agent.name for agent in agents_desc] assert desc_names.index("gamma_agent") < desc_names.index("beta_agent") < desc_names.index("alpha_agent") # After: Get agents after alpha_agent in ascending order (should exclude alpha) - after_alpha = server.agent_manager.list_agents(actor=default_user, after=agent_ids["alpha_agent"], ascending=True) + after_alpha = await server.agent_manager.list_agents_async(actor=default_user, after=agent_ids["alpha_agent"], ascending=True) after_names = [a.name for a in after_alpha] assert "alpha_agent" not in after_names assert "beta_agent" in after_names @@ -975,7 +990,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): assert after_names == ["beta_agent", "gamma_agent"] # Before: Get agents before gamma_agent in ascending order (should exclude gamma) - before_gamma = server.agent_manager.list_agents(actor=default_user, before=agent_ids["gamma_agent"], ascending=True) + before_gamma = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["gamma_agent"], ascending=True) before_names = [a.name for a in before_gamma] assert "gamma_agent" not in before_names assert "alpha_agent" in before_names @@ -983,12 +998,12 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): assert before_names == ["alpha_agent", "beta_agent"] # After: Get agents after gamma_agent in descending order (should exclude gamma, return beta then alpha) - after_gamma_desc = server.agent_manager.list_agents(actor=default_user, after=agent_ids["gamma_agent"], ascending=False) + after_gamma_desc = await server.agent_manager.list_agents_async(actor=default_user, after=agent_ids["gamma_agent"], ascending=False) after_names_desc = [a.name for a in after_gamma_desc] assert after_names_desc == ["beta_agent", "alpha_agent"] # Before: Get agents before alpha_agent in descending order (should exclude alpha) - before_alpha_desc = server.agent_manager.list_agents(actor=default_user, before=agent_ids["alpha_agent"], ascending=False) + before_alpha_desc = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["alpha_agent"], ascending=False) before_names_desc = [a.name for a in before_alpha_desc] assert before_names_desc == ["gamma_agent", "beta_agent"] @@ -1239,74 +1254,85 @@ def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_wi assert len(agents) == 0 # No agent should match -def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test listing agents that have ALL specified tags.""" # Create agents with multiple tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user) # Search for agents with all specified tags - agents = server.agent_manager.list_agents(actor=default_user, tags=["test", "gpt4"], match_all_tags=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["test", "gpt4"], match_all_tags=True) assert len(agents) == 2 agent_ids = [a.id for a in agents] assert sarah_agent.id in agent_ids assert charles_agent.id in agent_ids # Search for tags that only sarah_agent has - agents = server.agent_manager.list_agents(actor=default_user, tags=["test", "production"], match_all_tags=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["test", "production"], match_all_tags=True) assert len(agents) == 1 assert agents[0].id == sarah_agent.id -def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test listing agents that have ANY of the specified tags.""" # Create agents with different tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) # Search for agents with any of the specified tags - agents = server.agent_manager.list_agents(actor=default_user, tags=["production", "development"], match_all_tags=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["production", "development"], match_all_tags=False) assert len(agents) == 2 agent_ids = [a.id for a in agents] assert sarah_agent.id in agent_ids assert charles_agent.id in agent_ids # Search for tags where only sarah_agent matches - agents = server.agent_manager.list_agents(actor=default_user, tags=["production", "nonexistent"], match_all_tags=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["production", "nonexistent"], match_all_tags=False) assert len(agents) == 1 assert agents[0].id == sarah_agent.id -def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test listing agents when no tags match.""" # Create agents with tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) # Search for nonexistent tags - agents = server.agent_manager.list_agents(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=True) assert len(agents) == 0 - agents = server.agent_manager.list_agents(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=False) assert len(agents) == 0 -def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test combining tag search with other filters.""" # Create agents with specific names and tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user) + await server.agent_manager.update_agent_async( + sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user + ) + await server.agent_manager.update_agent_async( + charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user + ) # List agents with specific tag and name pattern - agents = server.agent_manager.list_agents(actor=default_user, tags=["production"], match_all_tags=True, name="production_agent") + agents = await server.agent_manager.list_agents_async( + actor=default_user, tags=["production"], match_all_tags=True, name="production_agent" + ) assert len(agents) == 1 assert agents[0].id == sarah_agent.id -def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization): +@pytest.mark.asyncio +async def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization, event_loop): """Test pagination when listing agents by tags.""" # Create first agent - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent1", tags=["pagination_test", "tag1"], @@ -1322,7 +1348,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps # Create second agent - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent2", tags=["pagination_test", "tag2"], @@ -1335,19 +1361,19 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul ) # Get first page - first_page = server.agent_manager.list_agents(actor=default_user, tags=["pagination_test"], match_all_tags=True, limit=1) + first_page = await server.agent_manager.list_agents_async(actor=default_user, tags=["pagination_test"], match_all_tags=True, limit=1) assert len(first_page) == 1 first_agent_id = first_page[0].id # Get second page using cursor - second_page = server.agent_manager.list_agents( + second_page = await server.agent_manager.list_agents_async( actor=default_user, tags=["pagination_test"], match_all_tags=True, after=first_agent_id, limit=1 ) assert len(second_page) == 1 assert second_page[0].id != first_agent_id # Get previous page using before - prev_page = server.agent_manager.list_agents( + prev_page = await server.agent_manager.list_agents_async( actor=default_user, tags=["pagination_test"], match_all_tags=True, before=second_page[0].id, limit=1 ) assert len(prev_page) == 1 @@ -1360,10 +1386,11 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul assert agent2.id in all_ids -def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization): +@pytest.mark.asyncio +async def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization, event_loop): """Test listing agents with query text filtering and pagination.""" # Create test agents with specific names and descriptions - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="Search Agent One", memory_blocks=[], @@ -1375,7 +1402,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def actor=default_user, ) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="Search Agent Two", memory_blocks=[], @@ -1387,7 +1414,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def actor=default_user, ) - agent3 = server.agent_manager.create_agent( + agent3 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="Different Agent", memory_blocks=[], @@ -1400,32 +1427,32 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def ) # Test query text filtering - search_results = server.agent_manager.list_agents(actor=default_user, query_text="search agent") + search_results = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent") assert len(search_results) == 2 search_agent_ids = {agent.id for agent in search_results} assert agent1.id in search_agent_ids assert agent2.id in search_agent_ids assert agent3.id not in search_agent_ids - different_results = server.agent_manager.list_agents(actor=default_user, query_text="different agent") + different_results = await server.agent_manager.list_agents_async(actor=default_user, query_text="different agent") assert len(different_results) == 1 assert different_results[0].id == agent3.id # Test pagination with query text - first_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", limit=1) + first_page = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent", limit=1) assert len(first_page) == 1 first_agent_id = first_page[0].id # Get second page using cursor - second_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", after=first_agent_id, limit=1) + second_page = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent", after=first_agent_id, limit=1) assert len(second_page) == 1 assert second_page[0].id != first_agent_id # Test before and after - all_agents = server.agent_manager.list_agents(actor=default_user, query_text="agent") + all_agents = await server.agent_manager.list_agents_async(actor=default_user, query_text="agent") assert len(all_agents) == 3 first_agent, second_agent, third_agent = all_agents - middle_agent = server.agent_manager.list_agents( + middle_agent = await server.agent_manager.list_agents_async( actor=default_user, query_text="search agent", before=third_agent.id, after=first_agent.id ) assert len(middle_agent) == 1 @@ -1449,7 +1476,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau does not fail and clears out message_ids if somehow it's non-empty. """ # Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages). - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) updated_agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user) assert updated_agent.message_ids == ["ghost-message-id"] @@ -1457,7 +1484,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 # Double check that physically no messages exist - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 @pytest.mark.asyncio @@ -1467,7 +1494,7 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent, does not fail and clears out message_ids if somehow it's non-empty. """ # Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages). - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) updated_agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user) assert updated_agent.message_ids == ["ghost-message-id"] @@ -1475,7 +1502,7 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent, reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True) assert len(reset_agent.message_ids) == 4 # Double check that physically no messages exist - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 4 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 4 @pytest.mark.asyncio @@ -1508,7 +1535,7 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a agent_before = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user) # This is 4 because creating the message does not necessarily add it to the in context message ids assert len(agent_before.message_ids) == 4 - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 6 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 6 # 2. Reset all messages reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) @@ -1517,10 +1544,11 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a assert len(reset_agent.message_ids) == 1 # 4. Verify the messages are physically removed - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 -def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user, event_loop): """ Test that calling reset_messages multiple times has no adverse effect. """ @@ -1537,15 +1565,16 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use # First reset reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 # Second reset should do nothing new reset_agent_again = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 -def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_modify_letta_message(server: SyncServer, sarah_agent, default_user, event_loop): """ Test updating a message. """ @@ -1560,32 +1589,32 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): # user message update_user_message = UpdateUserMessage(content="Hello, Sarah!") - original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user) + original_user_message = await server.message_manager.get_message_by_id_async(message_id=user_message.id, actor=default_user) assert original_user_message.content[0].text != update_user_message.content server.message_manager.update_message_by_letta_message( message_id=user_message.id, letta_message_update=update_user_message, actor=default_user ) - updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user) + updated_user_message = await server.message_manager.get_message_by_id_async(message_id=user_message.id, actor=default_user) assert updated_user_message.content[0].text == update_user_message.content # system message update_system_message = UpdateSystemMessage(content="You are a friendly assistant!") - original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user) + original_system_message = await server.message_manager.get_message_by_id_async(message_id=system_message.id, actor=default_user) assert original_system_message.content[0].text != update_system_message.content server.message_manager.update_message_by_letta_message( message_id=system_message.id, letta_message_update=update_system_message, actor=default_user ) - updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user) + updated_system_message = await server.message_manager.get_message_by_id_async(message_id=system_message.id, actor=default_user) assert updated_system_message.content[0].text == update_system_message.content # reasoning message update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking") - original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user) + original_reasoning_message = await server.message_manager.get_message_by_id_async(message_id=reasoning_message.id, actor=default_user) assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning server.message_manager.update_message_by_letta_message( message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user ) - updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user) + updated_reasoning_message = await server.message_manager.get_message_by_id_async(message_id=reasoning_message.id, actor=default_user) assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning # assistant message @@ -1597,14 +1626,14 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): return arguments["message"] update_assistant_message = UpdateAssistantMessage(content="I am an agent!") - original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user) + original_assistant_message = await server.message_manager.get_message_by_id_async(message_id=assistant_message.id, actor=default_user) print("ORIGINAL", original_assistant_message.tool_calls) print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0])) assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content server.message_manager.update_message_by_letta_message( message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user ) - updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user) + updated_assistant_message = await server.message_manager.get_message_by_id_async(message_id=assistant_message.id, actor=default_user) print("UPDATED", updated_assistant_message.tool_calls) print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0])) assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content @@ -2944,7 +2973,7 @@ def test_checkpoint_creates_history(server: SyncServer, default_user): # Act: checkpoint it block_manager.checkpoint_block(block_id=created_block.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: # Get BlockHistory entries for this block history_entries: List[BlockHistory] = session.query(BlockHistory).filter(BlockHistory.block_id == created_block.id).all() assert len(history_entries) == 1, "Exactly one history entry should be created" @@ -2977,7 +3006,7 @@ def test_multiple_checkpoints(server: SyncServer, default_user): # 3) Second checkpoint block_manager.checkpoint_block(block_id=block.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: history_entries = ( session.query(BlockHistory).filter(BlockHistory.block_id == block.id).order_by(BlockHistory.sequence_number.asc()).all() ) @@ -3010,7 +3039,7 @@ def test_checkpoint_with_agent_id(server: SyncServer, default_user, sarah_agent) block_manager.checkpoint_block(block_id=block.id, actor=default_user, agent_id=sarah_agent.id) # Verify - with db_context() as session: + with db_registry.session() as session: hist_entry = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).one() assert hist_entry.actor_type == ActorType.LETTA_AGENT assert hist_entry.actor_id == sarah_agent.id @@ -3031,7 +3060,7 @@ def test_checkpoint_with_no_state_change(server: SyncServer, default_user): # 2) checkpoint again (no changes) block_manager.checkpoint_block(block_id=block.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: all_hist = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).all() assert len(all_hist) == 2 @@ -3043,15 +3072,15 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user): block = block_manager.create_or_update_block(PydanticBlock(label="test_stale_checkpoint", value="hello"), actor=default_user) # session1 loads - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.get(Block, block.id) # version=1 # session2 loads - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.get(Block, block.id) # also version=1 # session1 checkpoint => version=2 - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.merge(block_s1) block_manager.checkpoint_block( block_id=block_s1.id, @@ -3062,7 +3091,7 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user): # session2 tries to checkpoint => sees old version=1 => stale error with pytest.raises(StaleDataError): - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.merge(block_s2) block_manager.checkpoint_block( block_id=block_s2.id, @@ -3093,7 +3122,7 @@ def test_checkpoint_no_future_states(server: SyncServer, default_user): # 3) Another checkpoint (no changes made) => should become seq=3, not delete anything block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: # We expect 3 rows in block_history, none removed history_rows = ( session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all() @@ -3190,7 +3219,7 @@ def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default # 5) Checkpoint => new seq=2, removing the old seq=2 and seq=3 block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: # Let's see which BlockHistory rows remain history_entries = ( session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all() @@ -3306,11 +3335,11 @@ def test_undo_concurrency_stale(server: SyncServer, default_user): # Now block is at seq=2 # session1 preloads the block - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.get(Block, block_v1.id) # version=? let's say 2 in memory # session2 also preloads the block - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.get(Block, block_v1.id) # also version=2 # Session1 -> undo to seq=1 @@ -3474,9 +3503,9 @@ def test_redo_concurrency_stale(server: SyncServer, default_user): # but there's a valid row for seq=3 in block_history (the 'v3' state). # 5) Simulate concurrency: two sessions each read the block at seq=2 - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.get(Block, block.id) - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.get(Block, block.id) # 6) Session1 redoes to seq=3 first -> success