From 9a20aa3aae15bc2a339a030d47b8f3b71ad583f7 Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 6 Jun 2025 11:05:09 -0700 Subject: [PATCH] feat: migrate modify group to async (#2670) --- letta/server/rest_api/routers/v1/groups.py | 6 +++--- letta/services/group_manager.py | 10 +++++----- tests/integration_test_sleeptime_agent.py | 6 +++--- tests/integration_test_voice_agent.py | 22 +++++++++++----------- tests/test_multi_agent.py | 4 ++-- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index 25049536..559b98d5 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -86,7 +86,7 @@ def create_group( @router.patch("/{group_id}", response_model=Group, operation_id="modify_group") -def modify_group( +async def modify_group( group_id: str, group: GroupUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), @@ -97,8 +97,8 @@ def modify_group( Create a new multi-agent group with the specified configuration. """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.group_manager.modify_group(group_id=group_id, group_update=group, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.group_manager.modify_group_async(group_id=group_id, group_update=group, actor=actor) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 57f532fb..d2b0a501 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -152,9 +152,9 @@ class GroupManager: @trace_method @enforce_types - def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: - with db_registry.session() as session: - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: + async with db_registry.async_session() as session: + group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) sleeptime_agent_frequency = None max_message_buffer_length = None @@ -206,11 +206,11 @@ class GroupManager: if group_update.description: group.description = group_update.description if group_update.agent_ids: - self._process_agent_relationship( + await self._process_agent_relationship_async( session=session, group=group, agent_ids=group_update.agent_ids, allow_partial=False, replace=True ) - group.update(session, actor=actor) + await group.update_async(session, actor=actor) return group.to_pydantic() @trace_method diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 6a099d26..8e87643c 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -89,7 +89,7 @@ async def test_sleeptime_group_chat(server, actor): assert "archival_memory_insert" not in main_agent_tools # 2. Override frequency for test - group = server.group_manager.modify_group( + group = await server.group_manager.modify_group_async( group_id=main_agent.multi_agent_group.id, group_update=GroupUpdate( manager_config=SleeptimeManagerUpdate( @@ -203,7 +203,7 @@ async def test_sleeptime_group_chat_v2(server, actor): assert "archival_memory_insert" not in main_agent_tools # 2. Override frequency for test - group = server.group_manager.modify_group( + group = await server.group_manager.modify_group_async( group_id=main_agent.multi_agent_group.id, group_update=GroupUpdate( manager_config=SleeptimeManagerUpdate( @@ -316,7 +316,7 @@ async def test_sleeptime_removes_redundant_information(server, actor): actor=actor, ) - group = server.group_manager.modify_group( + group = await server.group_manager.modify_group_async( group_id=main_agent.multi_agent_group.id, group_update=GroupUpdate( manager_config=SleeptimeManagerUpdate( diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index b3fc86dc..e0da680d 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -343,7 +343,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en @pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url): - server.group_manager.modify_group( + await server.group_manager.modify_group_async( group_id=group_id, group_update=GroupUpdate( manager_config=VoiceSleeptimeManagerUpdate( @@ -572,9 +572,9 @@ async def test_init_voice_convo_agent(voice_agent, server, actor, server_url): server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) -def _modify(group_id, server, actor, max_val, min_val): +async def _modify(group_id, server, actor, max_val, min_val): """Helper to invoke modify_group with voice_sleeptime config.""" - return server.group_manager.modify_group( + return await server.group_manager.modify_group_async( group_id=group_id, group_update=GroupUpdate( manager_config=VoiceSleeptimeManagerUpdate( @@ -587,23 +587,23 @@ def _modify(group_id, server, actor, max_val, min_val): ) -def test_valid_buffer_lengths_above_four(group_id, server, actor): +async def test_valid_buffer_lengths_above_four(group_id, server, actor): # both > 4 and max > min - updated = _modify(group_id, server, actor, max_val=10, min_val=5) + updated = await _modify(group_id, server, actor, max_val=10, min_val=5) assert updated.max_message_buffer_length == 10 assert updated.min_message_buffer_length == 5 -def test_valid_buffer_lengths_only_max(group_id, server, actor): +async def test_valid_buffer_lengths_only_max(group_id, server, actor): # both > 4 and max > min - updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None) + updated = await _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None) assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1 assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH -def test_valid_buffer_lengths_only_min(group_id, server, actor): +async def test_valid_buffer_lengths_only_min(group_id, server, actor): # both > 4 and max > min - updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1) + updated = await _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1) assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1 @@ -624,7 +624,7 @@ def test_valid_buffer_lengths_only_min(group_id, server, actor): (10, 1, "greater than 4"), ], ) -def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part): +async def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part): with pytest.raises(ValueError) as exc: - _modify(group_id, server, actor, max_val, min_val) + await _modify(group_id, server, actor, max_val, min_val) assert err_part in str(exc.value) diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 8389a167..5571b0f0 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -182,7 +182,7 @@ async def test_modify_group_pattern(server, actor, participant_agents, manager_a actor=actor, ) with pytest.raises(ValueError, match="Cannot change group pattern"): - server.group_manager.modify_group( + await server.group_manager.modify_group_async( group_id=group.id, group_update=GroupUpdate( manager_config=DynamicManagerUpdate( @@ -281,7 +281,7 @@ async def test_round_robin(server, actor, participant_agents): assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids) max_turns = 3 - group = server.group_manager.modify_group( + group = await server.group_manager.modify_group_async( group_id=group.id, group_update=GroupUpdate( agent_ids=[agent.id for agent in participant_agents][::-1],