From 00914e5308fd93918ced229df39b01ab081430e5 Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 16 May 2025 00:37:08 -0700 Subject: [PATCH] feat(asyncify): migrate actors(users) endpoints (#2211) --- .../interfaces/openai_streaming_interface.py | 10 ++- letta/orm/sqlalchemy_base.py | 25 +++++++ letta/server/rest_api/routers/v1/agents.py | 16 ++--- letta/server/rest_api/routers/v1/blocks.py | 2 +- letta/server/rest_api/routers/v1/groups.py | 4 +- letta/server/rest_api/routers/v1/messages.py | 10 +-- letta/server/rest_api/routers/v1/runs.py | 4 +- letta/server/rest_api/routers/v1/tools.py | 2 +- letta/server/rest_api/routers/v1/users.py | 18 ++--- letta/server/rest_api/routers/v1/voice.py | 2 +- letta/services/user_manager.py | 70 +++++++++++++++++++ tests/test_managers.py | 28 ++++---- 12 files changed, 143 insertions(+), 48 deletions(-) diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 5b4fade4..168d0521 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -1,16 +1,14 @@ from datetime import datetime, timezone -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import AsyncGenerator, List, Optional from openai import AsyncStream -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, PRE_EXECUTION_MESSAGE_ARG -from letta.interfaces.utils import _format_sse_chunk +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage -from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall -from letta.schemas.usage import LettaUsageStatistics from letta.server.rest_api.json_parser import OptimisticJSONParser from letta.streaming_utils import JSONInnerThoughtsExtractor diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index dda47c6c..889afb12 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -745,6 +745,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): self.is_deleted = True return self.update(db_session) + @handle_db_timeout + async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase": + """Soft delete a record asynchronously (mark as deleted).""" + logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)") + + if actor: + self._set_created_and_updated_by_fields(actor.id) + + self.is_deleted = True + return await self.update_async(db_session) + @handle_db_timeout def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None: """Permanently removes the record from the database.""" @@ -761,6 +772,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): else: logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted") + @handle_db_timeout + async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None: + """Permanently removes the record from the database asynchronously.""" + logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)") + + async with db_session as session: + try: + await session.delete(self) + await session.commit() + except Exception as e: + await session.rollback() + logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}") + raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}") + @handle_db_timeout def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": logger.debug(...) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 24226996..fea78e81 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -83,7 +83,7 @@ async def list_agents( """ # Retrieve the actor (user) details - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # Call list_agents directly without unnecessary dict handling return await server.agent_manager.list_agents_async( @@ -163,7 +163,7 @@ async def import_agent_serialized( """ Import a serialized agent file and recreate the agent in the system. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: serialized_data = await file.read() @@ -233,7 +233,7 @@ async def create_agent( Create a new agent with the specified configuration. """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.create_agent_async(agent, actor=actor) except Exception as e: traceback.print_exc() @@ -248,7 +248,7 @@ async def modify_agent( actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Update an existing agent""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor) @@ -628,7 +628,7 @@ async def send_message( Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent @@ -688,7 +688,7 @@ async def send_message_streaming( It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. """ request_start_timestamp_ns = get_utc_timestamp_ns() - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent @@ -788,7 +788,7 @@ async def send_message_async( Asynchronously process a user message and return a run object. The actual processing happens in the background, and the status can be checked using the run ID. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # Create a new job run = Run( @@ -842,6 +842,6 @@ async def list_agent_groups( actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Lists the groups for an agent""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) print("in list agents with manager_type", manager_type) return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor) diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 4a9ea8da..c9506906 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -26,7 +26,7 @@ async def list_blocks( server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.block_manager.get_blocks_async( actor=actor, label=label, diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index 3ed71153..c6c6fb12 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -135,7 +135,7 @@ async def send_group_message( Process a user message and return the group's response. This endpoint accepts a message from a user and processes it through through agents in the group based on the specified pattern """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) result = await server.send_group_message_to_agent( group_id=group_id, actor=actor, @@ -174,7 +174,7 @@ async def send_group_message_streaming( This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern. It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) result = await server.send_group_message_to_agent( group_id=group_id, actor=actor, diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index fe5e0f91..4d7d3588 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -52,7 +52,7 @@ async def create_messages_batch( detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.", ) - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) batch_job = BatchJob( user_id=actor.id, status=JobStatus.running, @@ -100,7 +100,7 @@ async def retrieve_batch_run( """ Get the status of a batch run. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor) @@ -118,7 +118,7 @@ async def list_batch_runs( List all batch runs. """ # TODO: filter - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH) return [BatchJob.from_job(job) for job in jobs] @@ -150,7 +150,7 @@ async def list_batch_messages( - For subsequent pages, use the ID of the last message from the previous response as the cursor - Results will include messages before/after the cursor based on sort_descending """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # First, verify the batch job exists and the user has access to it try: @@ -177,7 +177,7 @@ async def cancel_batch_run( """ Cancel a batch run. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor) diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index fd7e5131..8a8793a3 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -115,7 +115,7 @@ async def list_run_messages( if order not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: messages = server.job_manager.get_run_messages( @@ -182,7 +182,7 @@ async def list_run_steps( if order not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: steps = server.job_manager.get_job_steps( diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 8c9aeac0..be1b9c8c 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -87,7 +87,7 @@ async def list_tools( Get a list of all tools available to agents belonging to the org of the user """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) if name is not None: tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor) return [tool] if tool else [] diff --git a/letta/server/rest_api/routers/v1/users.py b/letta/server/rest_api/routers/v1/users.py index bf2de7ef..4b4bfd91 100644 --- a/letta/server/rest_api/routers/v1/users.py +++ b/letta/server/rest_api/routers/v1/users.py @@ -14,7 +14,7 @@ router = APIRouter(prefix="/users", tags=["users", "admin"]) @router.get("/", tags=["admin"], response_model=List[User], operation_id="list_users") -def list_users( +async def list_users( after: Optional[str] = Query(None), limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), @@ -23,7 +23,7 @@ def list_users( Get a list of all users in the database """ try: - users = server.user_manager.list_users(after=after, limit=limit) + users = await server.user_manager.list_actors_async(after=after, limit=limit) except HTTPException: raise except Exception as e: @@ -32,7 +32,7 @@ def list_users( @router.post("/", tags=["admin"], response_model=User, operation_id="create_user") -def create_user( +async def create_user( request: UserCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): @@ -40,33 +40,33 @@ def create_user( Create a new user in the database """ user = User(**request.model_dump()) - user = server.user_manager.create_user(user) + user = await server.user_manager.create_actor_async(user) return user @router.put("/", tags=["admin"], response_model=User, operation_id="update_user") -def update_user( +async def update_user( user: UserUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): """ Update a user in the database """ - user = server.user_manager.update_user(user) + user = await server.user_manager.update_actor_async(user) return user @router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user") -def delete_user( +async def delete_user( user_id: str = Query(..., description="The user_id key to be deleted."), server: "SyncServer" = Depends(get_letta_server), ): # TODO make a soft deletion, instead of a hard deletion try: - user = server.user_manager.get_user_by_id(user_id=user_id) + user = await server.user_manager.get_actor_by_id_async(actor_id=user_id) if user is None: raise HTTPException(status_code=404, detail=f"User does not exist") - server.user_manager.delete_user_by_id(user_id=user_id) + await server.user_manager.delete_actor_by_id_async(user_id=user_id) except HTTPException: raise except Exception as e: diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 7b3d7efd..5de14e9d 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -38,7 +38,7 @@ async def create_voice_chat_completions( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id) # Create OpenAI async client client = openai.AsyncClient( diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 9f6a72a5..b1c64100 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -44,6 +44,14 @@ class UserManager: new_user.create(session) return new_user.to_pydantic() + @enforce_types + async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser: + """Create a new user if it doesn't already exist (async version).""" + async with db_registry.async_session() as session: + new_user = UserModel(**pydantic_user.model_dump(to_orm=True)) + await new_user.create_async(session) + return new_user.to_pydantic() + @enforce_types def update_user(self, user_update: UserUpdate) -> PydanticUser: """Update user details.""" @@ -60,6 +68,22 @@ class UserManager: existing_user.update(session) return existing_user.to_pydantic() + @enforce_types + async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser: + """Update user details (async version).""" + async with db_registry.async_session() as session: + # Retrieve the existing user by ID + existing_user = await UserModel.read_async(db_session=session, identifier=user_update.id) + + # Update only the fields that are provided in UserUpdate + update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(existing_user, key, value) + + # Commit the updated user + await existing_user.update_async(session) + return existing_user.to_pydantic() + @enforce_types def delete_user_by_id(self, user_id: str): """Delete a user and their associated records (agents, sources, mappings).""" @@ -70,6 +94,14 @@ class UserManager: session.commit() + @enforce_types + async def delete_actor_by_id_async(self, user_id: str): + """Delete a user and their associated records (agents, sources, mappings) asynchronously.""" + async with db_registry.async_session() as session: + # Delete from user table + user = await UserModel.read_async(db_session=session, identifier=user_id) + await user.hard_delete_async(session) + @enforce_types def get_user_by_id(self, user_id: str) -> PydanticUser: """Fetch a user by ID.""" @@ -77,6 +109,13 @@ class UserManager: user = UserModel.read(db_session=session, identifier=user_id) return user.to_pydantic() + @enforce_types + async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser: + """Fetch a user by ID asynchronously.""" + async with db_registry.async_session() as session: + user = await UserModel.read_async(db_session=session, identifier=actor_id) + return user.to_pydantic() + @enforce_types def get_default_user(self) -> PydanticUser: """Fetch the default user. If it doesn't exist, create it.""" @@ -96,6 +135,26 @@ class UserManager: except NoResultFound: return self.get_default_user() + @enforce_types + async def get_default_actor_async(self) -> PydanticUser: + """Fetch the default user asynchronously. If it doesn't exist, create it.""" + try: + return await self.get_actor_by_id_async(self.DEFAULT_USER_ID) + except NoResultFound: + # Fall back to synchronous version since create_default_user isn't async yet + return self.create_default_user(org_id=self.DEFAULT_ORG_ID) + + @enforce_types + async def get_actor_or_default_async(self, actor_id: Optional[str] = None): + """Fetch the user or default user asynchronously.""" + if not actor_id: + return await self.get_default_actor_async() + + try: + return await self.get_actor_by_id_async(actor_id=actor_id) + except NoResultFound: + return await self.get_default_actor_async() + @enforce_types def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: """List all users with optional pagination.""" @@ -106,3 +165,14 @@ class UserManager: limit=limit, ) return [user.to_pydantic() for user in users] + + @enforce_types + async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: + """List all users with optional pagination (async version).""" + async with db_registry.async_session() as session: + users = await UserModel.list_async( + db_session=session, + after=after, + limit=limit, + ) + return [user.to_pydantic() for user in users] diff --git a/tests/test_managers.py b/tests/test_managers.py index dcb37b5f..043a3d27 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -124,16 +124,16 @@ def default_user(server: SyncServer, default_organization): @pytest.fixture -def other_user(server: SyncServer, default_organization): +async def other_user(server: SyncServer, default_organization): """Fixture to create and return the default user within the default organization.""" - user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=default_organization.id)) yield user @pytest.fixture -def other_user_different_org(server: SyncServer, other_organization): +async def other_user_different_org(server: SyncServer, other_organization): """Fixture to create and return the default user within the default organization.""" - user = server.user_manager.create_user(PydanticUser(name="other", organization_id=other_organization.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=other_organization.id)) yield user @@ -2120,20 +2120,21 @@ def test_passage_cascade_deletion( # ====================================================================================================================== # User Manager Tests # ====================================================================================================================== -def test_list_users(server: SyncServer): +@pytest.mark.asyncio +async def test_list_users(server: SyncServer, event_loop): # Create default organization org = server.organization_manager.create_default_organization() user_name = "user" - user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id)) - users = server.user_manager.list_users() + users = await server.user_manager.list_actors_async() assert len(users) == 1 assert users[0].name == user_name # Delete it after - server.user_manager.delete_user_by_id(user.id) - assert len(server.user_manager.list_users()) == 0 + await server.user_manager.delete_actor_by_id_async(user.id) + assert len(await server.user_manager.list_actors_async()) == 0 def test_create_default_user(server: SyncServer): @@ -2143,7 +2144,8 @@ def test_create_default_user(server: SyncServer): assert retrieved.name == server.user_manager.DEFAULT_USER_NAME -def test_update_user(server: SyncServer): +@pytest.mark.asyncio +async def test_update_user(server: SyncServer, event_loop): # Create default organization default_org = server.organization_manager.create_default_organization() test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org")) @@ -2152,16 +2154,16 @@ def test_update_user(server: SyncServer): user_name_b = "b" # Assert it's been created - user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name=user_name_a, organization_id=default_org.id)) assert user.name == user_name_a # Adjust name - user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b)) + user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b)) assert user.name == user_name_b assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID # Adjust org id - user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id)) + user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id)) assert user.name == user_name_b assert user.organization_id == test_org.id