diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index dd48fd4e..d563aab4 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -13,7 +13,7 @@ router = APIRouter(prefix="/identities", tags=["identities"]) @router.get("/", tags=["identities"], response_model=List[Identity], operation_id="list_identities") -def list_identities( +async def list_identities( name: Optional[str] = Query(None), project_id: Optional[str] = Query(None), identifier_key: Optional[str] = Query(None), @@ -28,9 +28,9 @@ def list_identities( Get a list of all identities in the database """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - identities = server.identity_manager.list_identities( + identities = await server.identity_manager.list_identities_async( name=name, project_id=project_id, identifier_key=identifier_key, @@ -50,7 +50,7 @@ def list_identities( @router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities") -def count_identities( +async def count_identities( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -58,7 +58,8 @@ def count_identities( Get count of all identities for a user """ try: - return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.size_async(actor=actor) except NoResultFound: return 0 except HTTPException: @@ -68,14 +69,14 @@ def count_identities( @router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity") -def retrieve_identity( +async def retrieve_identity( identity_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.identity_manager.get_identity(identity_id=identity_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 3ca05793..3531dc5c 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -17,7 +17,7 @@ from letta.utils import enforce_types class IdentityManager: @enforce_types - def list_identities( + async def list_identities_async( self, name: Optional[str] = None, project_id: Optional[str] = None, @@ -28,7 +28,7 @@ class IdentityManager: limit: Optional[int] = 50, actor: PydanticUser = None, ) -> list[PydanticIdentity]: - with db_registry.session() as session: + async with db_registry.async_session() as session: filters = {"organization_id": actor.organization_id} if project_id: filters["project_id"] = project_id @@ -36,7 +36,7 @@ class IdentityManager: filters["identifier_key"] = identifier_key if identity_type: filters["identity_type"] = identity_type - identities = IdentityModel.list( + identities = await IdentityModel.list_async( db_session=session, query_text=name, before=before, @@ -47,9 +47,9 @@ class IdentityManager: return [identity.to_pydantic() for identity in identities] @enforce_types - def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: - with db_registry.session() as session: - identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) + async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: + async with db_registry.async_session() as session: + identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) return identity.to_pydantic() @enforce_types @@ -187,15 +187,15 @@ class IdentityManager: session.commit() @enforce_types - def size( + async def size_async( self, actor: PydanticUser, ) -> int: """ Get the total count of identities for the given user. """ - with db_registry.session() as session: - return IdentityModel.size(db_session=session, actor=actor) + async with db_registry.async_session() as session: + return await IdentityModel.size_async(db_session=session, actor=actor) def _process_relationship( self, diff --git a/tests/test_managers.py b/tests/test_managers.py index 043a3d27..05e77017 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3495,7 +3495,8 @@ def test_redo_concurrency_stale(server: SyncServer, default_user): # ====================================================================================================================== -def test_create_and_upsert_identity(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_create_and_upsert_identity(server: SyncServer, default_user, event_loop): identity_create = IdentityCreate( identifier_key="1234", name="caren", @@ -3526,7 +3527,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user): identity = server.identity_manager.upsert_identity(identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user) - identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user) + identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user) assert len(identity.properties) == 1 assert identity.properties[0].key == "age" assert identity.properties[0].value == 29 @@ -3534,7 +3535,8 @@ def test_create_and_upsert_identity(server: SyncServer, default_user): server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) -def test_get_identities(server, default_user): +@pytest.mark.asyncio +async def test_get_identities(server, default_user): # Create identities to retrieve later user = server.identity_manager.create_identity( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user @@ -3544,14 +3546,14 @@ def test_get_identities(server, default_user): ) # Retrieve identities by different filters - all_identities = server.identity_manager.list_identities(actor=default_user) + all_identities = await server.identity_manager.list_identities_async(actor=default_user) assert len(all_identities) == 2 - user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user) + user_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.user) assert len(user_identities) == 1 assert user_identities[0].name == user.name - org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org) + org_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.org) assert len(org_identities) == 1 assert org_identities[0].name == org.name @@ -3573,7 +3575,7 @@ async def test_update_identity(server: SyncServer, sarah_agent, charles_agent, d server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user) # Retrieve the updated identity - updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user) + updated_identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user) # Assertions to verify the update assert updated_identity.agent_ids.sort() == update_data.agent_ids.sort() @@ -3604,7 +3606,7 @@ async def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) # Verify that the identity was deleted - identities = server.identity_manager.list_identities(actor=default_user) + identities = await server.identity_manager.list_identities_async(actor=default_user) assert len(identities) == 0 # Check that block has been detached too @@ -3692,7 +3694,7 @@ async def test_attach_detach_identity_from_block(server: SyncServer, default_blo server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) # Verify that the identity was deleted - identities = server.identity_manager.list_identities(actor=default_user) + identities = await server.identity_manager.list_identities_async(actor=default_user) assert len(identities) == 0 # Check that block has been detached too