fix: identities session management (#2555)
This commit is contained in:
@@ -3,7 +3,6 @@ from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.block import Block as BlockModel
|
||||
@@ -60,26 +59,29 @@ class IdentityManager:
|
||||
@trace_method
|
||||
async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
|
||||
new_identity.organization_id = actor.organization_id
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
identity=new_identity,
|
||||
relationship_name="agents",
|
||||
model_class=AgentModel,
|
||||
item_ids=identity.agent_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
identity=new_identity,
|
||||
relationship_name="blocks",
|
||||
model_class=BlockModel,
|
||||
item_ids=identity.block_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
await new_identity.create_async(session, actor=actor)
|
||||
return new_identity.to_pydantic()
|
||||
return await self._create_identity_async(db_session=session, identity=identity, actor=actor)
|
||||
|
||||
async def _create_identity_async(self, db_session, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
|
||||
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
|
||||
new_identity.organization_id = actor.organization_id
|
||||
await self._process_relationship_async(
|
||||
db_session=db_session,
|
||||
identity=new_identity,
|
||||
relationship_name="agents",
|
||||
model_class=AgentModel,
|
||||
item_ids=identity.agent_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
await self._process_relationship_async(
|
||||
db_session=db_session,
|
||||
identity=new_identity,
|
||||
relationship_name="blocks",
|
||||
model_class=BlockModel,
|
||||
item_ids=identity.block_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
await new_identity.create_async(db_session=db_session, actor=actor)
|
||||
return new_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -94,7 +96,7 @@ class IdentityManager:
|
||||
)
|
||||
|
||||
if existing_identity is None:
|
||||
return await self.create_identity_async(identity=IdentityCreate(**identity.model_dump()), actor=actor)
|
||||
return await self._create_identity_async(db_session=session, identity=IdentityCreate(**identity.model_dump()), actor=actor)
|
||||
else:
|
||||
identity_update = IdentityUpdate(
|
||||
name=identity.name,
|
||||
@@ -104,7 +106,7 @@ class IdentityManager:
|
||||
properties=identity.properties,
|
||||
)
|
||||
return await self._update_identity_async(
|
||||
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
|
||||
db_session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@@ -121,12 +123,12 @@ class IdentityManager:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
return await self._update_identity_async(
|
||||
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
|
||||
db_session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
|
||||
)
|
||||
|
||||
async def _update_identity_async(
|
||||
self,
|
||||
session: Session,
|
||||
db_session,
|
||||
existing_identity: IdentityModel,
|
||||
identity: IdentityUpdate,
|
||||
actor: PydanticUser,
|
||||
@@ -149,7 +151,7 @@ class IdentityManager:
|
||||
|
||||
if identity.agent_ids is not None:
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
db_session=db_session,
|
||||
identity=existing_identity,
|
||||
relationship_name="agents",
|
||||
model_class=AgentModel,
|
||||
@@ -159,7 +161,7 @@ class IdentityManager:
|
||||
)
|
||||
if identity.block_ids is not None:
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
db_session=db_session,
|
||||
identity=existing_identity,
|
||||
relationship_name="blocks",
|
||||
model_class=BlockModel,
|
||||
@@ -167,7 +169,7 @@ class IdentityManager:
|
||||
allow_partial=False,
|
||||
replace=replace,
|
||||
)
|
||||
await existing_identity.update_async(session, actor=actor)
|
||||
await existing_identity.update_async(db_session=db_session, actor=actor)
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -180,7 +182,7 @@ class IdentityManager:
|
||||
if existing_identity is None:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
return await self._update_identity_async(
|
||||
session=session,
|
||||
db_session=session,
|
||||
existing_identity=existing_identity,
|
||||
identity=IdentityUpdate(properties=properties),
|
||||
actor=actor,
|
||||
@@ -213,7 +215,7 @@ class IdentityManager:
|
||||
|
||||
async def _process_relationship_async(
|
||||
self,
|
||||
session: Session,
|
||||
db_session,
|
||||
identity: PydanticIdentity,
|
||||
relationship_name: str,
|
||||
model_class,
|
||||
@@ -228,7 +230,7 @@ class IdentityManager:
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = (await session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()
|
||||
found_items = (await db_session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(item_ids):
|
||||
|
||||
Reference in New Issue
Block a user