feat(asyncify): migrate create block (#2368)
This commit is contained in:
@@ -50,47 +50,46 @@ def count_blocks(
|
||||
|
||||
|
||||
@router.post("/", response_model=Block, operation_id="create_block")
|
||||
def create_block(
|
||||
async def create_block(
|
||||
create_block: CreateBlock = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
block = Block(**create_block.model_dump())
|
||||
return server.block_manager.create_or_update_block(actor=actor, block=block)
|
||||
return await server.block_manager.create_or_update_block_async(actor=actor, block=block)
|
||||
|
||||
|
||||
@router.patch("/{block_id}", response_model=Block, operation_id="modify_block")
|
||||
def modify_block(
|
||||
async def modify_block(
|
||||
block_id: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.block_manager.update_block_async(block_id=block_id, block_update=block_update, actor=actor)
|
||||
|
||||
|
||||
@router.delete("/{block_id}", response_model=Block, operation_id="delete_block")
|
||||
def delete_block(
|
||||
async def delete_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.delete_block(block_id=block_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.block_manager.delete_block_async(block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{block_id}", response_model=Block, operation_id="retrieve_block")
|
||||
def retrieve_block(
|
||||
async def retrieve_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
print("call get block", block_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
try:
|
||||
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
|
||||
block = await server.block_manager.get_block_by_id_async(block_id=block_id, actor=actor)
|
||||
if block is None:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
return block
|
||||
|
||||
@@ -38,6 +38,21 @@ class BlockManager:
|
||||
block.create(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Create a new block based on the Block schema."""
|
||||
db_block = await self.get_block_by_id_async(block.id, actor)
|
||||
if db_block:
|
||||
update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
|
||||
return await self.update_block_async(block.id, update_data, actor)
|
||||
else:
|
||||
async with db_registry.async_session() as session:
|
||||
data = block.model_dump(to_orm=True, exclude_none=True)
|
||||
block = BlockModel(**data, organization_id=actor.organization_id)
|
||||
await block.create_async(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
|
||||
@@ -78,6 +93,22 @@ class BlockManager:
|
||||
block.update(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
# Safety check for block
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(block, key, value)
|
||||
|
||||
await block.update_async(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
|
||||
@@ -87,6 +118,15 @@ class BlockManager:
|
||||
block.hard_delete(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Delete a block by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
await block.hard_delete_async(db_session=session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def get_blocks_async(
|
||||
@@ -161,6 +201,17 @@ class BlockManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
return block.to_pydantic()
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
|
||||
Reference in New Issue
Block a user