feat: migrate modify group to async (#2670)

This commit is contained in:
cthomas
2025-06-06 11:05:09 -07:00
committed by GitHub
parent 1a80267d44
commit 4be038c4b9
5 changed files with 24 additions and 24 deletions

View File

@@ -86,7 +86,7 @@ def create_group(
@router.patch("/{group_id}", response_model=Group, operation_id="modify_group")
def modify_group(
async def modify_group(
group_id: str,
group: GroupUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
@@ -97,8 +97,8 @@ def modify_group(
Create a new multi-agent group with the specified configuration.
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.modify_group(group_id=group_id, group_update=group, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.group_manager.modify_group_async(group_id=group_id, group_update=group, actor=actor)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -152,9 +152,9 @@ class GroupManager:
@trace_method
@enforce_types
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
with db_registry.session() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
async with db_registry.async_session() as session:
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
sleeptime_agent_frequency = None
max_message_buffer_length = None
@@ -206,11 +206,11 @@ class GroupManager:
if group_update.description:
group.description = group_update.description
if group_update.agent_ids:
self._process_agent_relationship(
await self._process_agent_relationship_async(
session=session, group=group, agent_ids=group_update.agent_ids, allow_partial=False, replace=True
)
group.update(session, actor=actor)
await group.update_async(session, actor=actor)
return group.to_pydantic()
@trace_method

View File

@@ -89,7 +89,7 @@ async def test_sleeptime_group_chat(server, actor):
assert "archival_memory_insert" not in main_agent_tools
# 2. Override frequency for test
group = server.group_manager.modify_group(
group = await server.group_manager.modify_group_async(
group_id=main_agent.multi_agent_group.id,
group_update=GroupUpdate(
manager_config=SleeptimeManagerUpdate(
@@ -203,7 +203,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
assert "archival_memory_insert" not in main_agent_tools
# 2. Override frequency for test
group = server.group_manager.modify_group(
group = await server.group_manager.modify_group_async(
group_id=main_agent.multi_agent_group.id,
group_update=GroupUpdate(
manager_config=SleeptimeManagerUpdate(
@@ -316,7 +316,7 @@ async def test_sleeptime_removes_redundant_information(server, actor):
actor=actor,
)
group = server.group_manager.modify_group(
group = await server.group_manager.modify_group_async(
group_id=main_agent.multi_agent_group.id,
group_update=GroupUpdate(
manager_config=SleeptimeManagerUpdate(

View File

@@ -343,7 +343,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url):
server.group_manager.modify_group(
await server.group_manager.modify_group_async(
group_id=group_id,
group_update=GroupUpdate(
manager_config=VoiceSleeptimeManagerUpdate(
@@ -572,9 +572,9 @@ async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
def _modify(group_id, server, actor, max_val, min_val):
async def _modify(group_id, server, actor, max_val, min_val):
"""Helper to invoke modify_group with voice_sleeptime config."""
return server.group_manager.modify_group(
return await server.group_manager.modify_group_async(
group_id=group_id,
group_update=GroupUpdate(
manager_config=VoiceSleeptimeManagerUpdate(
@@ -587,23 +587,23 @@ def _modify(group_id, server, actor, max_val, min_val):
)
def test_valid_buffer_lengths_above_four(group_id, server, actor):
async def test_valid_buffer_lengths_above_four(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=10, min_val=5)
updated = await _modify(group_id, server, actor, max_val=10, min_val=5)
assert updated.max_message_buffer_length == 10
assert updated.min_message_buffer_length == 5
def test_valid_buffer_lengths_only_max(group_id, server, actor):
async def test_valid_buffer_lengths_only_max(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
updated = await _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
def test_valid_buffer_lengths_only_min(group_id, server, actor):
async def test_valid_buffer_lengths_only_min(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
updated = await _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1
@@ -624,7 +624,7 @@ def test_valid_buffer_lengths_only_min(group_id, server, actor):
(10, 1, "greater than 4"),
],
)
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
async def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
with pytest.raises(ValueError) as exc:
_modify(group_id, server, actor, max_val, min_val)
await _modify(group_id, server, actor, max_val, min_val)
assert err_part in str(exc.value)

View File

@@ -182,7 +182,7 @@ async def test_modify_group_pattern(server, actor, participant_agents, manager_a
actor=actor,
)
with pytest.raises(ValueError, match="Cannot change group pattern"):
server.group_manager.modify_group(
await server.group_manager.modify_group_async(
group_id=group.id,
group_update=GroupUpdate(
manager_config=DynamicManagerUpdate(
@@ -281,7 +281,7 @@ async def test_round_robin(server, actor, participant_agents):
assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids)
max_turns = 3
group = server.group_manager.modify_group(
group = await server.group_manager.modify_group_async(
group_id=group.id,
group_update=GroupUpdate(
agent_ids=[agent.id for agent in participant_agents][::-1],