feat(asyncify): migrate identities routes (#2234)

This commit is contained in:
cthomas
2025-05-18 20:45:57 -07:00
committed by GitHub
parent 53e9fcf5f0
commit cec4b89c27
3 changed files with 29 additions and 26 deletions

View File

@@ -13,7 +13,7 @@ router = APIRouter(prefix="/identities", tags=["identities"])
@router.get("/", tags=["identities"], response_model=List[Identity], operation_id="list_identities")
def list_identities(
async def list_identities(
name: Optional[str] = Query(None),
project_id: Optional[str] = Query(None),
identifier_key: Optional[str] = Query(None),
@@ -28,9 +28,9 @@ def list_identities(
Get a list of all identities in the database
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
identities = server.identity_manager.list_identities(
identities = await server.identity_manager.list_identities_async(
name=name,
project_id=project_id,
identifier_key=identifier_key,
@@ -50,7 +50,7 @@ def list_identities(
@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities")
def count_identities(
async def count_identities(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
@@ -58,7 +58,8 @@ def count_identities(
Get count of all identities for a user
"""
try:
return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.size_async(actor=actor)
except NoResultFound:
return 0
except HTTPException:
@@ -68,14 +69,14 @@ def count_identities(
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
def retrieve_identity(
async def retrieve_identity(
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -17,7 +17,7 @@ from letta.utils import enforce_types
class IdentityManager:
@enforce_types
def list_identities(
async def list_identities_async(
self,
name: Optional[str] = None,
project_id: Optional[str] = None,
@@ -28,7 +28,7 @@ class IdentityManager:
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> list[PydanticIdentity]:
with db_registry.session() as session:
async with db_registry.async_session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
@@ -36,7 +36,7 @@ class IdentityManager:
filters["identifier_key"] = identifier_key
if identity_type:
filters["identity_type"] = identity_type
identities = IdentityModel.list(
identities = await IdentityModel.list_async(
db_session=session,
query_text=name,
before=before,
@@ -47,9 +47,9 @@ class IdentityManager:
return [identity.to_pydantic() for identity in identities]
@enforce_types
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()
@enforce_types
@@ -187,15 +187,15 @@ class IdentityManager:
session.commit()
@enforce_types
def size(
async def size_async(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of identities for the given user.
"""
with db_registry.session() as session:
return IdentityModel.size(db_session=session, actor=actor)
async with db_registry.async_session() as session:
return await IdentityModel.size_async(db_session=session, actor=actor)
def _process_relationship(
self,

View File

@@ -3495,7 +3495,8 @@ def test_redo_concurrency_stale(server: SyncServer, default_user):
# ======================================================================================================================
def test_create_and_upsert_identity(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_and_upsert_identity(server: SyncServer, default_user, event_loop):
identity_create = IdentityCreate(
identifier_key="1234",
name="caren",
@@ -3526,7 +3527,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
identity = server.identity_manager.upsert_identity(identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user)
identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user)
assert len(identity.properties) == 1
assert identity.properties[0].key == "age"
assert identity.properties[0].value == 29
@@ -3534,7 +3535,8 @@ def test_create_and_upsert_identity(server: SyncServer, default_user):
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
def test_get_identities(server, default_user):
@pytest.mark.asyncio
async def test_get_identities(server, default_user):
# Create identities to retrieve later
user = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
@@ -3544,14 +3546,14 @@ def test_get_identities(server, default_user):
)
# Retrieve identities by different filters
all_identities = server.identity_manager.list_identities(actor=default_user)
all_identities = await server.identity_manager.list_identities_async(actor=default_user)
assert len(all_identities) == 2
user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
user_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.user)
assert len(user_identities) == 1
assert user_identities[0].name == user.name
org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
org_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.org)
assert len(org_identities) == 1
assert org_identities[0].name == org.name
@@ -3573,7 +3575,7 @@ async def test_update_identity(server: SyncServer, sarah_agent, charles_agent, d
server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user)
# Retrieve the updated identity
updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
updated_identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user)
# Assertions to verify the update
assert updated_identity.agent_ids.sort() == update_data.agent_ids.sort()
@@ -3604,7 +3606,7 @@ async def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
# Verify that the identity was deleted
identities = server.identity_manager.list_identities(actor=default_user)
identities = await server.identity_manager.list_identities_async(actor=default_user)
assert len(identities) == 0
# Check that block has been detached too
@@ -3692,7 +3694,7 @@ async def test_attach_detach_identity_from_block(server: SyncServer, default_blo
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
# Verify that the identity was deleted
identities = server.identity_manager.list_identities(actor=default_user)
identities = await server.identity_manager.list_identities_async(actor=default_user)
assert len(identities) == 0
# Check that block has been detached too