feat: asyncify groups operations [LET-4068] (#4254)
feat: asyncify groups operations
This commit is contained in:
@@ -17,7 +17,7 @@ router = APIRouter(prefix="/groups", tags=["groups"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Group], operation_id="list_groups")
|
||||
def list_groups(
|
||||
async def list_groups(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"),
|
||||
@@ -29,8 +29,8 @@ def list_groups(
|
||||
"""
|
||||
Fetch all multi-agent groups matching query.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.list_groups(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.group_manager.list_groups_async(
|
||||
actor=actor,
|
||||
project_id=project_id,
|
||||
manager_type=manager_type,
|
||||
@@ -41,14 +41,15 @@ def list_groups(
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_groups")
|
||||
def count_groups(
|
||||
async def count_groups(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get the count of all groups associated with a given user.
|
||||
"""
|
||||
return server.group_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.group_manager.size(actor=actor)
|
||||
|
||||
|
||||
@router.get("/{group_id}", response_model=Group, operation_id="retrieve_group")
|
||||
@@ -69,7 +70,7 @@ async def retrieve_group(
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
|
||||
def create_group(
|
||||
async def create_group(
|
||||
group: GroupCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -81,8 +82,8 @@ def create_group(
|
||||
Create a new multi-agent group with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(group, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.group_manager.create_group_async(group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -108,7 +109,7 @@ async def modify_group(
|
||||
|
||||
|
||||
@router.delete("/{group_id}", response_model=None, operation_id="delete_group")
|
||||
def delete_group(
|
||||
async def delete_group(
|
||||
group_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -116,9 +117,9 @@ def delete_group(
|
||||
"""
|
||||
Delete a multi-agent group.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
try:
|
||||
server.group_manager.delete_group(group_id=group_id, actor=actor)
|
||||
await server.group_manager.delete_group_async(group_id=group_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"})
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.")
|
||||
@@ -199,7 +200,7 @@ GroupMessagesResponse = Annotated[
|
||||
|
||||
|
||||
@router.patch("/{group_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_group_message")
|
||||
def modify_group_message(
|
||||
async def modify_group_message(
|
||||
group_id: str,
|
||||
message_id: str,
|
||||
request: LettaMessageUpdateUnion = Body(...),
|
||||
@@ -210,12 +211,12 @@ def modify_group_message(
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
# TODO: support modifying tool calls/returns
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
|
||||
def list_group_messages(
|
||||
async def list_group_messages(
|
||||
group_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
@@ -229,10 +230,10 @@ def list_group_messages(
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
group = await server.group_manager.retrieve_group_async(group_id=group_id, actor=actor)
|
||||
if group.manager_agent_id:
|
||||
return server.get_agent_recall(
|
||||
return await server.get_agent_recall_async(
|
||||
user_id=actor.id,
|
||||
agent_id=group.manager_agent_id,
|
||||
after=after,
|
||||
@@ -246,7 +247,7 @@ def list_group_messages(
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
else:
|
||||
return server.group_manager.list_group_messages(
|
||||
return await server.group_manager.list_group_messages_async(
|
||||
group_id=group_id,
|
||||
after=after,
|
||||
before=before,
|
||||
@@ -259,7 +260,7 @@ def list_group_messages(
|
||||
|
||||
|
||||
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
|
||||
def reset_group_messages(
|
||||
async def reset_group_messages(
|
||||
group_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -267,5 +268,5 @@ def reset_group_messages(
|
||||
"""
|
||||
Delete the group messages for all agents that are part of the multi-agent group.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.group_manager.reset_messages(group_id=group_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
await server.group_manager.reset_messages_async(group_id=group_id, actor=actor)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -18,9 +18,10 @@ from letta.utils import enforce_types
|
||||
|
||||
|
||||
class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_groups(
|
||||
async def list_groups_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
project_id: Optional[str] = None,
|
||||
@@ -29,13 +30,13 @@ class GroupManager:
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
) -> list[PydanticGroup]:
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
if manager_type:
|
||||
filters["manager_type"] = manager_type
|
||||
groups = GroupModel.list(
|
||||
groups = await GroupModel.list_async(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
@@ -274,6 +275,43 @@ class GroupManager:
|
||||
|
||||
return messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_group_messages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
group_id: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = "send_message",
|
||||
assistant_message_tool_kwarg: str = "message",
|
||||
) -> list[LettaMessage]:
|
||||
async with db_registry.async_session() as session:
|
||||
filters = {
|
||||
"organization_id": actor.organization_id,
|
||||
"group_id": group_id,
|
||||
}
|
||||
messages = await MessageModel.list_async(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
**filters,
|
||||
)
|
||||
|
||||
messages = PydanticMessage.to_letta_messages_from_list(
|
||||
messages=[msg.to_pydantic() for msg in messages],
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
# TODO: filter messages to return a clean conversation history
|
||||
|
||||
return messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
|
||||
@@ -288,6 +326,21 @@ class GroupManager:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
# Delete all messages in the group
|
||||
delete_stmt = delete(MessageModel).where(
|
||||
MessageModel.organization_id == actor.organization_id, MessageModel.group_id == group_id
|
||||
)
|
||||
await session.execute(delete_stmt)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
|
||||
@@ -342,15 +395,15 @@ class GroupManager:
|
||||
return prev_last_processed_message_id
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
async def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of groups for the given user.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
return GroupModel.size(db_session=session, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
return await GroupModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
|
||||
if not agent_ids:
|
||||
|
||||
Reference in New Issue
Block a user