feat(asyncify): migrate identities routes (#2234)
This commit is contained in:
@@ -13,7 +13,7 @@ router = APIRouter(prefix="/identities", tags=["identities"])
|
||||
|
||||
|
||||
@router.get("/", tags=["identities"], response_model=List[Identity], operation_id="list_identities")
|
||||
def list_identities(
|
||||
async def list_identities(
|
||||
name: Optional[str] = Query(None),
|
||||
project_id: Optional[str] = Query(None),
|
||||
identifier_key: Optional[str] = Query(None),
|
||||
@@ -28,9 +28,9 @@ def list_identities(
|
||||
Get a list of all identities in the database
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
identities = server.identity_manager.list_identities(
|
||||
identities = await server.identity_manager.list_identities_async(
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
identifier_key=identifier_key,
|
||||
@@ -50,7 +50,7 @@ def list_identities(
|
||||
|
||||
|
||||
@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities")
|
||||
def count_identities(
|
||||
async def count_identities(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@@ -58,7 +58,8 @@ def count_identities(
|
||||
Get count of all identities for a user
|
||||
"""
|
||||
try:
|
||||
return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.size_async(actor=actor)
|
||||
except NoResultFound:
|
||||
return 0
|
||||
except HTTPException:
|
||||
@@ -68,14 +69,14 @@ def count_identities(
|
||||
|
||||
|
||||
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
|
||||
def retrieve_identity(
|
||||
async def retrieve_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from letta.utils import enforce_types
|
||||
class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
def list_identities(
|
||||
async def list_identities_async(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
@@ -28,7 +28,7 @@ class IdentityManager:
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> list[PydanticIdentity]:
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
@@ -36,7 +36,7 @@ class IdentityManager:
|
||||
filters["identifier_key"] = identifier_key
|
||||
if identity_type:
|
||||
filters["identity_type"] = identity_type
|
||||
identities = IdentityModel.list(
|
||||
identities = await IdentityModel.list_async(
|
||||
db_session=session,
|
||||
query_text=name,
|
||||
before=before,
|
||||
@@ -47,9 +47,9 @@ class IdentityManager:
|
||||
return [identity.to_pydantic() for identity in identities]
|
||||
|
||||
@enforce_types
|
||||
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
with db_registry.session() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
return identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -187,15 +187,15 @@ class IdentityManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
async def size_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of identities for the given user.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
return IdentityModel.size(db_session=session, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
return await IdentityModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
def _process_relationship(
|
||||
self,
|
||||
|
||||
@@ -3495,7 +3495,8 @@ def test_redo_concurrency_stale(server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_and_upsert_identity(server: SyncServer, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_upsert_identity(server: SyncServer, default_user, event_loop):
|
||||
identity_create = IdentityCreate(
|
||||
identifier_key="1234",
|
||||
name="caren",
|
||||
@@ -3526,7 +3527,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
|
||||
|
||||
identity = server.identity_manager.upsert_identity(identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user)
|
||||
|
||||
identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
|
||||
identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user)
|
||||
assert len(identity.properties) == 1
|
||||
assert identity.properties[0].key == "age"
|
||||
assert identity.properties[0].value == 29
|
||||
@@ -3534,7 +3535,8 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_get_identities(server, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_identities(server, default_user):
|
||||
# Create identities to retrieve later
|
||||
user = server.identity_manager.create_identity(
|
||||
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
|
||||
@@ -3544,14 +3546,14 @@ def test_get_identities(server, default_user):
|
||||
)
|
||||
|
||||
# Retrieve identities by different filters
|
||||
all_identities = server.identity_manager.list_identities(actor=default_user)
|
||||
all_identities = await server.identity_manager.list_identities_async(actor=default_user)
|
||||
assert len(all_identities) == 2
|
||||
|
||||
user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
|
||||
user_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.user)
|
||||
assert len(user_identities) == 1
|
||||
assert user_identities[0].name == user.name
|
||||
|
||||
org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
|
||||
org_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.org)
|
||||
assert len(org_identities) == 1
|
||||
assert org_identities[0].name == org.name
|
||||
|
||||
@@ -3573,7 +3575,7 @@ async def test_update_identity(server: SyncServer, sarah_agent, charles_agent, d
|
||||
server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user)
|
||||
|
||||
# Retrieve the updated identity
|
||||
updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
|
||||
updated_identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user)
|
||||
|
||||
# Assertions to verify the update
|
||||
assert updated_identity.agent_ids.sort() == update_data.agent_ids.sort()
|
||||
@@ -3604,7 +3606,7 @@ async def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
# Verify that the identity was deleted
|
||||
identities = server.identity_manager.list_identities(actor=default_user)
|
||||
identities = await server.identity_manager.list_identities_async(actor=default_user)
|
||||
assert len(identities) == 0
|
||||
|
||||
# Check that block has been detached too
|
||||
@@ -3692,7 +3694,7 @@ async def test_attach_detach_identity_from_block(server: SyncServer, default_blo
|
||||
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
|
||||
|
||||
# Verify that the identity was deleted
|
||||
identities = server.identity_manager.list_identities(actor=default_user)
|
||||
identities = await server.identity_manager.list_identities_async(actor=default_user)
|
||||
assert len(identities) == 0
|
||||
|
||||
# Check that block has been detached too
|
||||
|
||||
Reference in New Issue
Block a user