From 664bd4739357131972f8dfd2b8cd29badabcab92 Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 23 May 2025 09:07:13 -0700 Subject: [PATCH] feat(asyncify): migrate create block (#2368) --- letta/server/rest_api/routers/v1/blocks.py | 25 +++++------ letta/services/block_manager.py | 51 ++++++++++++++++++++++ 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index bf669f43..d31fd855 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -50,47 +50,46 @@ def count_blocks( @router.post("/", response_model=Block, operation_id="create_block") -def create_block( +async def create_block( create_block: CreateBlock = Body(...), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) block = Block(**create_block.model_dump()) - return server.block_manager.create_or_update_block(actor=actor, block=block) + return await server.block_manager.create_or_update_block_async(actor=actor, block=block) @router.patch("/{block_id}", response_model=Block, operation_id="modify_block") -def modify_block( +async def modify_block( block_id: str, block_update: BlockUpdate = Body(...), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.block_manager.update_block_async(block_id=block_id, block_update=block_update, actor=actor) @router.delete("/{block_id}", response_model=Block, operation_id="delete_block") -def delete_block( +async def delete_block( block_id: str, server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.block_manager.delete_block(block_id=block_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.block_manager.delete_block_async(block_id=block_id, actor=actor) @router.get("/{block_id}", response_model=Block, operation_id="retrieve_block") -def retrieve_block( +async def retrieve_block( block_id: str, server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): - print("call get block", block_id) - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) + block = await server.block_manager.get_block_by_id_async(block_id=block_id, actor=actor) if block is None: raise HTTPException(status_code=404, detail="Block not found") return block diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 0795ed7f..fd46e86a 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -38,6 +38,21 @@ class BlockManager: block.create(session, actor=actor) return block.to_pydantic() + @trace_method + @enforce_types + async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: + """Create a new block based on the Block schema.""" + db_block = await self.get_block_by_id_async(block.id, actor) + if db_block: + update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True)) + return await self.update_block_async(block.id, update_data, actor) + else: + async with db_registry.async_session() as session: + data = block.model_dump(to_orm=True, exclude_none=True) + block = BlockModel(**data, organization_id=actor.organization_id) + await block.create_async(session, actor=actor) + return block.to_pydantic() + @trace_method @enforce_types def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: @@ -78,6 +93,22 @@ class BlockManager: block.update(db_session=session, actor=actor) return block.to_pydantic() + @trace_method + @enforce_types + async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: + """Update a block by its ID with the given BlockUpdate object.""" + # Safety check for block + + async with db_registry.async_session() as session: + block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) + update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + for key, value in update_data.items(): + setattr(block, key, value) + + await block.update_async(db_session=session, actor=actor) + return block.to_pydantic() + @trace_method @enforce_types def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: @@ -87,6 +118,15 @@ class BlockManager: block.hard_delete(db_session=session, actor=actor) return block.to_pydantic() + @trace_method + @enforce_types + async def delete_block_async(self, block_id: str, actor: PydanticUser) -> PydanticBlock: + """Delete a block by its ID.""" + async with db_registry.async_session() as session: + block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) + await block.hard_delete_async(db_session=session, actor=actor) + return block.to_pydantic() + @trace_method @enforce_types async def get_blocks_async( @@ -161,6 +201,17 @@ class BlockManager: except NoResultFound: return None + @trace_method + @enforce_types + async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: + """Retrieve a block by its name.""" + async with db_registry.async_session() as session: + try: + block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) + return block.to_pydantic() + except NoResultFound: + return None + @trace_method @enforce_types async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: