From df8d285db36de785c4c3908accb3b3bd6a96392a Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 30 May 2025 13:58:04 -0700 Subject: [PATCH] fix: identities session management (#2555) --- letta/services/identity_manager.py | 64 +++++++++++++++--------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 97fc9765..f91854fd 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -3,7 +3,6 @@ from typing import List, Optional from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Session from letta.orm.agent import Agent as AgentModel from letta.orm.block import Block as BlockModel @@ -60,26 +59,29 @@ class IdentityManager: @trace_method async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: async with db_registry.async_session() as session: - new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True)) - new_identity.organization_id = actor.organization_id - await self._process_relationship_async( - session=session, - identity=new_identity, - relationship_name="agents", - model_class=AgentModel, - item_ids=identity.agent_ids, - allow_partial=False, - ) - await self._process_relationship_async( - session=session, - identity=new_identity, - relationship_name="blocks", - model_class=BlockModel, - item_ids=identity.block_ids, - allow_partial=False, - ) - await new_identity.create_async(session, actor=actor) - return new_identity.to_pydantic() + return await self._create_identity_async(db_session=session, identity=identity, actor=actor) + + async def _create_identity_async(self, db_session, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: + new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True)) + new_identity.organization_id = actor.organization_id + await self._process_relationship_async( + db_session=db_session, + identity=new_identity, + relationship_name="agents", + model_class=AgentModel, + item_ids=identity.agent_ids, + allow_partial=False, + ) + await self._process_relationship_async( + db_session=db_session, + identity=new_identity, + relationship_name="blocks", + model_class=BlockModel, + item_ids=identity.block_ids, + allow_partial=False, + ) + await new_identity.create_async(db_session=db_session, actor=actor) + return new_identity.to_pydantic() @enforce_types @trace_method @@ -94,7 +96,7 @@ class IdentityManager: ) if existing_identity is None: - return await self.create_identity_async(identity=IdentityCreate(**identity.model_dump()), actor=actor) + return await self._create_identity_async(db_session=session, identity=IdentityCreate(**identity.model_dump()), actor=actor) else: identity_update = IdentityUpdate( name=identity.name, @@ -104,7 +106,7 @@ class IdentityManager: properties=identity.properties, ) return await self._update_identity_async( - session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True + db_session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True ) @enforce_types @@ -121,12 +123,12 @@ class IdentityManager: raise HTTPException(status_code=403, detail="Forbidden") return await self._update_identity_async( - session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace + db_session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace ) async def _update_identity_async( self, - session: Session, + db_session, existing_identity: IdentityModel, identity: IdentityUpdate, actor: PydanticUser, @@ -149,7 +151,7 @@ class IdentityManager: if identity.agent_ids is not None: await self._process_relationship_async( - session=session, + db_session=db_session, identity=existing_identity, relationship_name="agents", model_class=AgentModel, @@ -159,7 +161,7 @@ class IdentityManager: ) if identity.block_ids is not None: await self._process_relationship_async( - session=session, + db_session=db_session, identity=existing_identity, relationship_name="blocks", model_class=BlockModel, @@ -167,7 +169,7 @@ class IdentityManager: allow_partial=False, replace=replace, ) - await existing_identity.update_async(session, actor=actor) + await existing_identity.update_async(db_session=db_session, actor=actor) return existing_identity.to_pydantic() @enforce_types @@ -180,7 +182,7 @@ class IdentityManager: if existing_identity is None: raise HTTPException(status_code=404, detail="Identity not found") return await self._update_identity_async( - session=session, + db_session=session, existing_identity=existing_identity, identity=IdentityUpdate(properties=properties), actor=actor, @@ -213,7 +215,7 @@ class IdentityManager: async def _process_relationship_async( self, - session: Session, + db_session, identity: PydanticIdentity, relationship_name: str, model_class, @@ -228,7 +230,7 @@ class IdentityManager: return # Retrieve models for the provided IDs - found_items = (await session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all() + found_items = (await db_session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all() # Validate all items are found if allow_partial is False if not allow_partial and len(found_items) != len(item_ids):