feat: migrate modify group to async (#2670)
This commit is contained in:
@@ -86,7 +86,7 @@ def create_group(
|
||||
|
||||
|
||||
@router.patch("/{group_id}", response_model=Group, operation_id="modify_group")
|
||||
def modify_group(
|
||||
async def modify_group(
|
||||
group_id: str,
|
||||
group: GroupUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -97,8 +97,8 @@ def modify_group(
|
||||
Create a new multi-agent group with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.modify_group(group_id=group_id, group_update=group, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.group_manager.modify_group_async(group_id=group_id, group_update=group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -152,9 +152,9 @@ class GroupManager:
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
|
||||
with db_registry.session() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
sleeptime_agent_frequency = None
|
||||
max_message_buffer_length = None
|
||||
@@ -206,11 +206,11 @@ class GroupManager:
|
||||
if group_update.description:
|
||||
group.description = group_update.description
|
||||
if group_update.agent_ids:
|
||||
self._process_agent_relationship(
|
||||
await self._process_agent_relationship_async(
|
||||
session=session, group=group, agent_ids=group_update.agent_ids, allow_partial=False, replace=True
|
||||
)
|
||||
|
||||
group.update(session, actor=actor)
|
||||
await group.update_async(session, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
|
||||
@@ -89,7 +89,7 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
assert "archival_memory_insert" not in main_agent_tools
|
||||
|
||||
# 2. Override frequency for test
|
||||
group = server.group_manager.modify_group(
|
||||
group = await server.group_manager.modify_group_async(
|
||||
group_id=main_agent.multi_agent_group.id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=SleeptimeManagerUpdate(
|
||||
@@ -203,7 +203,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
|
||||
assert "archival_memory_insert" not in main_agent_tools
|
||||
|
||||
# 2. Override frequency for test
|
||||
group = server.group_manager.modify_group(
|
||||
group = await server.group_manager.modify_group_async(
|
||||
group_id=main_agent.multi_agent_group.id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=SleeptimeManagerUpdate(
|
||||
@@ -316,7 +316,7 @@ async def test_sleeptime_removes_redundant_information(server, actor):
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
group = server.group_manager.modify_group(
|
||||
group = await server.group_manager.modify_group_async(
|
||||
group_id=main_agent.multi_agent_group.id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=SleeptimeManagerUpdate(
|
||||
|
||||
@@ -343,7 +343,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url):
|
||||
server.group_manager.modify_group(
|
||||
await server.group_manager.modify_group_async(
|
||||
group_id=group_id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=VoiceSleeptimeManagerUpdate(
|
||||
@@ -572,9 +572,9 @@ async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
|
||||
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
|
||||
|
||||
def _modify(group_id, server, actor, max_val, min_val):
|
||||
async def _modify(group_id, server, actor, max_val, min_val):
|
||||
"""Helper to invoke modify_group with voice_sleeptime config."""
|
||||
return server.group_manager.modify_group(
|
||||
return await server.group_manager.modify_group_async(
|
||||
group_id=group_id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=VoiceSleeptimeManagerUpdate(
|
||||
@@ -587,23 +587,23 @@ def _modify(group_id, server, actor, max_val, min_val):
|
||||
)
|
||||
|
||||
|
||||
def test_valid_buffer_lengths_above_four(group_id, server, actor):
|
||||
async def test_valid_buffer_lengths_above_four(group_id, server, actor):
|
||||
# both > 4 and max > min
|
||||
updated = _modify(group_id, server, actor, max_val=10, min_val=5)
|
||||
updated = await _modify(group_id, server, actor, max_val=10, min_val=5)
|
||||
assert updated.max_message_buffer_length == 10
|
||||
assert updated.min_message_buffer_length == 5
|
||||
|
||||
|
||||
def test_valid_buffer_lengths_only_max(group_id, server, actor):
|
||||
async def test_valid_buffer_lengths_only_max(group_id, server, actor):
|
||||
# both > 4 and max > min
|
||||
updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
|
||||
updated = await _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
|
||||
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
|
||||
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
|
||||
|
||||
|
||||
def test_valid_buffer_lengths_only_min(group_id, server, actor):
|
||||
async def test_valid_buffer_lengths_only_min(group_id, server, actor):
|
||||
# both > 4 and max > min
|
||||
updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
|
||||
updated = await _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
|
||||
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
|
||||
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1
|
||||
|
||||
@@ -624,7 +624,7 @@ def test_valid_buffer_lengths_only_min(group_id, server, actor):
|
||||
(10, 1, "greater than 4"),
|
||||
],
|
||||
)
|
||||
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
|
||||
async def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_modify(group_id, server, actor, max_val, min_val)
|
||||
await _modify(group_id, server, actor, max_val, min_val)
|
||||
assert err_part in str(exc.value)
|
||||
|
||||
@@ -182,7 +182,7 @@ async def test_modify_group_pattern(server, actor, participant_agents, manager_a
|
||||
actor=actor,
|
||||
)
|
||||
with pytest.raises(ValueError, match="Cannot change group pattern"):
|
||||
server.group_manager.modify_group(
|
||||
await server.group_manager.modify_group_async(
|
||||
group_id=group.id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=DynamicManagerUpdate(
|
||||
@@ -281,7 +281,7 @@ async def test_round_robin(server, actor, participant_agents):
|
||||
assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids)
|
||||
|
||||
max_turns = 3
|
||||
group = server.group_manager.modify_group(
|
||||
group = await server.group_manager.modify_group_async(
|
||||
group_id=group.id,
|
||||
group_update=GroupUpdate(
|
||||
agent_ids=[agent.id for agent in participant_agents][::-1],
|
||||
|
||||
Reference in New Issue
Block a user