diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index 14f95115..8359b163 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -17,7 +17,7 @@ router = APIRouter(prefix="/groups", tags=["groups"]) @router.get("/", response_model=List[Group], operation_id="list_groups") -def list_groups( +async def list_groups( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"), @@ -29,8 +29,8 @@ def list_groups( """ Fetch all multi-agent groups matching query. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.group_manager.list_groups( + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.group_manager.list_groups_async( actor=actor, project_id=project_id, manager_type=manager_type, @@ -41,14 +41,15 @@ def list_groups( @router.get("/count", response_model=int, operation_id="count_groups") -def count_groups( +async def count_groups( server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Get the count of all groups associated with a given user. """ - return server.group_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.group_manager.size(actor=actor) @router.get("/{group_id}", response_model=Group, operation_id="retrieve_group") @@ -69,7 +70,7 @@ async def retrieve_group( @router.post("/", response_model=Group, operation_id="create_group") -def create_group( +async def create_group( group: GroupCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -81,8 +82,8 @@ def create_group( Create a new multi-agent group with the specified configuration. """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.group_manager.create_group(group, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.group_manager.create_group_async(group, actor=actor) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -108,7 +109,7 @@ async def modify_group( @router.delete("/{group_id}", response_model=None, operation_id="delete_group") -def delete_group( +async def delete_group( group_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -116,9 +117,9 @@ def delete_group( """ Delete a multi-agent group. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - server.group_manager.delete_group(group_id=group_id, actor=actor) + await server.group_manager.delete_group_async(group_id=group_id, actor=actor) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"}) except NoResultFound: raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.") @@ -199,7 +200,7 @@ GroupMessagesResponse = Annotated[ @router.patch("/{group_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_group_message") -def modify_group_message( +async def modify_group_message( group_id: str, message_id: str, request: LettaMessageUpdateUnion = Body(...), @@ -210,12 +211,12 @@ def modify_group_message( Update the details of a message associated with an agent. """ # TODO: support modifying tool calls/returns - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor) @router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages") -def list_group_messages( +async def list_group_messages( group_id: str, server: "SyncServer" = Depends(get_letta_server), after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."), @@ -229,10 +230,10 @@ def list_group_messages( """ Retrieve message history for an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - group = server.group_manager.retrieve_group(group_id=group_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + group = await server.group_manager.retrieve_group_async(group_id=group_id, actor=actor) if group.manager_agent_id: - return server.get_agent_recall( + return await server.get_agent_recall_async( user_id=actor.id, agent_id=group.manager_agent_id, after=after, @@ -246,7 +247,7 @@ def list_group_messages( assistant_message_tool_kwarg=assistant_message_tool_kwarg, ) else: - return server.group_manager.list_group_messages( + return await server.group_manager.list_group_messages_async( group_id=group_id, after=after, before=before, @@ -259,7 +260,7 @@ def list_group_messages( @router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages") -def reset_group_messages( +async def reset_group_messages( group_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -267,5 +268,5 @@ def reset_group_messages( """ Delete the group messages for all agents that are part of the multi-agent group. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - server.group_manager.reset_messages(group_id=group_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + await server.group_manager.reset_messages_async(group_id=group_id, actor=actor) diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 7dfffe15..fb393add 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -1,6 +1,6 @@ from typing import List, Optional, Union -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.orm import Session from letta.orm.agent import Agent as AgentModel @@ -18,9 +18,10 @@ from letta.utils import enforce_types class GroupManager: + @enforce_types @trace_method - def list_groups( + async def list_groups_async( self, actor: PydanticUser, project_id: Optional[str] = None, @@ -29,13 +30,13 @@ class GroupManager: after: Optional[str] = None, limit: Optional[int] = 50, ) -> list[PydanticGroup]: - with db_registry.session() as session: + async with db_registry.async_session() as session: filters = {"organization_id": actor.organization_id} if project_id: filters["project_id"] = project_id if manager_type: filters["manager_type"] = manager_type - groups = GroupModel.list( + groups = await GroupModel.list_async( db_session=session, before=before, after=after, @@ -274,6 +275,43 @@ class GroupManager: return messages + @enforce_types + @trace_method + async def list_group_messages_async( + self, + actor: PydanticUser, + group_id: Optional[str] = None, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + use_assistant_message: bool = True, + assistant_message_tool_name: str = "send_message", + assistant_message_tool_kwarg: str = "message", + ) -> list[LettaMessage]: + async with db_registry.async_session() as session: + filters = { + "organization_id": actor.organization_id, + "group_id": group_id, + } + messages = await MessageModel.list_async( + db_session=session, + before=before, + after=after, + limit=limit, + **filters, + ) + + messages = PydanticMessage.to_letta_messages_from_list( + messages=[msg.to_pydantic() for msg in messages], + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + ) + + # TODO: filter messages to return a clean conversation history + + return messages + @enforce_types @trace_method def reset_messages(self, group_id: str, actor: PydanticUser) -> None: @@ -288,6 +326,21 @@ class GroupManager: session.commit() + @enforce_types + @trace_method + async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None: + async with db_registry.async_session() as session: + # Ensure group is loadable by user + group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) + + # Delete all messages in the group + delete_stmt = delete(MessageModel).where( + MessageModel.organization_id == actor.organization_id, MessageModel.group_id == group_id + ) + await session.execute(delete_stmt) + + await session.commit() + @enforce_types @trace_method def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int: @@ -342,15 +395,15 @@ class GroupManager: return prev_last_processed_message_id @enforce_types - def size( + async def size( self, actor: PydanticUser, ) -> int: """ Get the total count of groups for the given user. """ - with db_registry.session() as session: - return GroupModel.size(db_session=session, actor=actor) + async with db_registry.async_session() as session: + return await GroupModel.size_async(db_session=session, actor=actor) def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True): if not agent_ids: