fix: identities session management (#2555)

This commit is contained in:
cthomas
2025-05-30 13:58:04 -07:00
committed by GitHub
parent 91cd3211ca
commit df8d285db3

View File

@@ -3,7 +3,6 @@ from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from letta.orm.agent import Agent as AgentModel
from letta.orm.block import Block as BlockModel
@@ -60,26 +59,29 @@ class IdentityManager:
@trace_method
async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
await self._process_relationship_async(
session=session,
identity=new_identity,
relationship_name="agents",
model_class=AgentModel,
item_ids=identity.agent_ids,
allow_partial=False,
)
await self._process_relationship_async(
session=session,
identity=new_identity,
relationship_name="blocks",
model_class=BlockModel,
item_ids=identity.block_ids,
allow_partial=False,
)
await new_identity.create_async(session, actor=actor)
return new_identity.to_pydantic()
return await self._create_identity_async(db_session=session, identity=identity, actor=actor)
async def _create_identity_async(self, db_session, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
await self._process_relationship_async(
db_session=db_session,
identity=new_identity,
relationship_name="agents",
model_class=AgentModel,
item_ids=identity.agent_ids,
allow_partial=False,
)
await self._process_relationship_async(
db_session=db_session,
identity=new_identity,
relationship_name="blocks",
model_class=BlockModel,
item_ids=identity.block_ids,
allow_partial=False,
)
await new_identity.create_async(db_session=db_session, actor=actor)
return new_identity.to_pydantic()
@enforce_types
@trace_method
@@ -94,7 +96,7 @@ class IdentityManager:
)
if existing_identity is None:
return await self.create_identity_async(identity=IdentityCreate(**identity.model_dump()), actor=actor)
return await self._create_identity_async(db_session=session, identity=IdentityCreate(**identity.model_dump()), actor=actor)
else:
identity_update = IdentityUpdate(
name=identity.name,
@@ -104,7 +106,7 @@ class IdentityManager:
properties=identity.properties,
)
return await self._update_identity_async(
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
db_session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
)
@enforce_types
@@ -121,12 +123,12 @@ class IdentityManager:
raise HTTPException(status_code=403, detail="Forbidden")
return await self._update_identity_async(
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
db_session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
)
async def _update_identity_async(
self,
session: Session,
db_session,
existing_identity: IdentityModel,
identity: IdentityUpdate,
actor: PydanticUser,
@@ -149,7 +151,7 @@ class IdentityManager:
if identity.agent_ids is not None:
await self._process_relationship_async(
session=session,
db_session=db_session,
identity=existing_identity,
relationship_name="agents",
model_class=AgentModel,
@@ -159,7 +161,7 @@ class IdentityManager:
)
if identity.block_ids is not None:
await self._process_relationship_async(
session=session,
db_session=db_session,
identity=existing_identity,
relationship_name="blocks",
model_class=BlockModel,
@@ -167,7 +169,7 @@ class IdentityManager:
allow_partial=False,
replace=replace,
)
await existing_identity.update_async(session, actor=actor)
await existing_identity.update_async(db_session=db_session, actor=actor)
return existing_identity.to_pydantic()
@enforce_types
@@ -180,7 +182,7 @@ class IdentityManager:
if existing_identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
return await self._update_identity_async(
session=session,
db_session=session,
existing_identity=existing_identity,
identity=IdentityUpdate(properties=properties),
actor=actor,
@@ -213,7 +215,7 @@ class IdentityManager:
async def _process_relationship_async(
self,
session: Session,
db_session,
identity: PydanticIdentity,
relationship_name: str,
model_class,
@@ -228,7 +230,7 @@ class IdentityManager:
return
# Retrieve models for the provided IDs
found_items = (await session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()
found_items = (await db_session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()
# Validate all items are found if allow_partial is False
if not allow_partial and len(found_items) != len(item_ids):