feat: asyncify groups operations [LET-4068] (#4254)

feat: asyncify groups operations
This commit is contained in:
cthomas
2025-08-27 12:00:53 -07:00
committed by GitHub
parent 0d1282a09b
commit 3cd746456a
2 changed files with 83 additions and 29 deletions

View File

@@ -17,7 +17,7 @@ router = APIRouter(prefix="/groups", tags=["groups"])
@router.get("/", response_model=List[Group], operation_id="list_groups")
def list_groups(
async def list_groups(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"),
@@ -29,8 +29,8 @@ def list_groups(
"""
Fetch all multi-agent groups matching query.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.list_groups(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.group_manager.list_groups_async(
actor=actor,
project_id=project_id,
manager_type=manager_type,
@@ -41,14 +41,15 @@ def list_groups(
@router.get("/count", response_model=int, operation_id="count_groups")
def count_groups(
async def count_groups(
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Get the count of all groups associated with a given user.
"""
return server.group_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.group_manager.size(actor=actor)
@router.get("/{group_id}", response_model=Group, operation_id="retrieve_group")
@@ -69,7 +70,7 @@ async def retrieve_group(
@router.post("/", response_model=Group, operation_id="create_group")
def create_group(
async def create_group(
group: GroupCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -81,8 +82,8 @@ def create_group(
Create a new multi-agent group with the specified configuration.
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.create_group(group, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.group_manager.create_group_async(group, actor=actor)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -108,7 +109,7 @@ async def modify_group(
@router.delete("/{group_id}", response_model=None, operation_id="delete_group")
def delete_group(
async def delete_group(
group_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -116,9 +117,9 @@ def delete_group(
"""
Delete a multi-agent group.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
server.group_manager.delete_group(group_id=group_id, actor=actor)
await server.group_manager.delete_group_async(group_id=group_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"})
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.")
@@ -199,7 +200,7 @@ GroupMessagesResponse = Annotated[
@router.patch("/{group_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_group_message")
def modify_group_message(
async def modify_group_message(
group_id: str,
message_id: str,
request: LettaMessageUpdateUnion = Body(...),
@@ -210,12 +211,12 @@ def modify_group_message(
Update the details of a message associated with an agent.
"""
# TODO: support modifying tool calls/returns
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
def list_group_messages(
async def list_group_messages(
group_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
@@ -229,10 +230,10 @@ def list_group_messages(
"""
Retrieve message history for an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
group = await server.group_manager.retrieve_group_async(group_id=group_id, actor=actor)
if group.manager_agent_id:
return server.get_agent_recall(
return await server.get_agent_recall_async(
user_id=actor.id,
agent_id=group.manager_agent_id,
after=after,
@@ -246,7 +247,7 @@ def list_group_messages(
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
else:
return server.group_manager.list_group_messages(
return await server.group_manager.list_group_messages_async(
group_id=group_id,
after=after,
before=before,
@@ -259,7 +260,7 @@ def list_group_messages(
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
def reset_group_messages(
async def reset_group_messages(
group_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -267,5 +268,5 @@ def reset_group_messages(
"""
Delete the group messages for all agents that are part of the multi-agent group.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.group_manager.reset_messages(group_id=group_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
await server.group_manager.reset_messages_async(group_id=group_id, actor=actor)

View File

@@ -1,6 +1,6 @@
from typing import List, Optional, Union
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from letta.orm.agent import Agent as AgentModel
@@ -18,9 +18,10 @@ from letta.utils import enforce_types
class GroupManager:
@enforce_types
@trace_method
def list_groups(
async def list_groups_async(
self,
actor: PydanticUser,
project_id: Optional[str] = None,
@@ -29,13 +30,13 @@ class GroupManager:
after: Optional[str] = None,
limit: Optional[int] = 50,
) -> list[PydanticGroup]:
with db_registry.session() as session:
async with db_registry.async_session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
if manager_type:
filters["manager_type"] = manager_type
groups = GroupModel.list(
groups = await GroupModel.list_async(
db_session=session,
before=before,
after=after,
@@ -274,6 +275,43 @@ class GroupManager:
return messages
@enforce_types
@trace_method
async def list_group_messages_async(
self,
actor: PydanticUser,
group_id: Optional[str] = None,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
use_assistant_message: bool = True,
assistant_message_tool_name: str = "send_message",
assistant_message_tool_kwarg: str = "message",
) -> list[LettaMessage]:
async with db_registry.async_session() as session:
filters = {
"organization_id": actor.organization_id,
"group_id": group_id,
}
messages = await MessageModel.list_async(
db_session=session,
before=before,
after=after,
limit=limit,
**filters,
)
messages = PydanticMessage.to_letta_messages_from_list(
messages=[msg.to_pydantic() for msg in messages],
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
# TODO: filter messages to return a clean conversation history
return messages
@enforce_types
@trace_method
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
@@ -288,6 +326,21 @@ class GroupManager:
session.commit()
@enforce_types
@trace_method
async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None:
async with db_registry.async_session() as session:
# Ensure group is loadable by user
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
# Delete all messages in the group
delete_stmt = delete(MessageModel).where(
MessageModel.organization_id == actor.organization_id, MessageModel.group_id == group_id
)
await session.execute(delete_stmt)
await session.commit()
@enforce_types
@trace_method
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
@@ -342,15 +395,15 @@ class GroupManager:
return prev_last_processed_message_id
@enforce_types
def size(
async def size(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of groups for the given user.
"""
with db_registry.session() as session:
return GroupModel.size(db_session=session, actor=actor)
async with db_registry.async_session() as session:
return await GroupModel.size_async(db_session=session, actor=actor)
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
if not agent_ids: