# letta-server/letta/services/identity_manager.py
# Snapshot: 2025-02-19 22:16:21 -08:00 (141 lines, 6.3 KiB, Python source)

from typing import List, Optional

from fastapi import HTTPException
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from letta.orm.agent import Agent as AgentModel
from letta.orm.identity import Identity as IdentityModel
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityCreate, IdentityType, IdentityUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class IdentityManager:
    """Service layer for creating, listing, updating and deleting identities.

    Each public method opens a session from ``db_context`` (directly, or via the
    method it delegates to) and returns Pydantic schema objects rather than ORM rows.
    """

    def __init__(self):
        # Imported lazily; presumably to avoid an import cycle with letta.server — verify.
        from letta.server.db import db_context

        self.session_maker = db_context

    @enforce_types
    def list_identities(
        self,
        name: Optional[str] = None,
        project_id: Optional[str] = None,
        identity_type: Optional[IdentityType] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        actor: PydanticUser = None,
    ) -> list[PydanticIdentity]:
        """List identities belonging to the actor's organization.

        Args:
            name: Optional search text, forwarded to ``IdentityModel.list`` as ``query_text``.
            project_id: If truthy, only identities in this project are returned.
            identity_type: If truthy, only identities of this type are returned.
            before: Pagination cursor forwarded to ``IdentityModel.list``.
            after: Pagination cursor forwarded to ``IdentityModel.list``.
            limit: Maximum number of rows to return (default 50).
            actor: User whose ``organization_id`` scopes the query. Must be supplied
                despite the ``None`` default; its ``organization_id`` is read unconditionally.

        Returns:
            The matching identities as Pydantic models.
        """
        with self.session_maker() as session:
            filters = {"organization_id": actor.organization_id}
            if project_id:
                filters["project_id"] = project_id
            if identity_type:
                filters["identity_type"] = identity_type
            identities = IdentityModel.list(
                db_session=session,
                query_text=name,
                before=before,
                after=after,
                limit=limit,
                **filters,
            )
            return [identity.to_pydantic() for identity in identities]

    @enforce_types
    def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
        """Fetch a single identity by ``identifier_key``.

        The lookup is not scoped to an organization (no actor is taken), and a miss
        is not translated into an HTTP error here: whatever ``IdentityModel.read``
        does on a miss propagates to the caller.
        """
        with self.session_maker() as session:
            identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
            return identity.to_pydantic()

    @enforce_types
    def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
        """Create an identity in the actor's organization and link its agents.

        Args:
            identity: Creation payload. ``agent_ids`` is handled separately from the
                other fields and turned into the ``agents`` relationship.
            actor: User performing the creation; supplies ``organization_id``.

        Returns:
            The newly created identity.

        Raises:
            NoResultFound: If any of ``identity.agent_ids`` does not exist.
        """
        with self.session_maker() as session:
            new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids"}, exclude_unset=True))
            new_identity.organization_id = actor.organization_id
            self._process_agent_relationship(
                session=session, identity=new_identity, agent_ids=identity.agent_ids, allow_partial=False
            )
            new_identity.create(session, actor=actor)
            return new_identity.to_pydantic()

    @enforce_types
    def upsert_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
        """Create the identity, or overwrite the mutable fields of an existing one.

        An existing identity is matched on ``identifier_key``, ``project_id`` and the
        actor's organization. When one is found, its name, type and agent list are
        replaced via ``update_identity_by_key(..., replace=True)``. Otherwise a new
        identity is created via ``create_identity``.

        Raises:
            HTTPException: 400 if the matched row disagrees on an immutable field.
        """
        with self.session_maker() as session:
            existing_identity = self._read_identity_or_none(
                session,
                identifier_key=identity.identifier_key,
                project_id=identity.project_id,
                organization_id=actor.organization_id,
            )
            if existing_identity is not None:
                # Defensive only: the lookup above already filters on both fields.
                if existing_identity.identifier_key != identity.identifier_key:
                    raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
                if existing_identity.project_id != identity.project_id:
                    raise HTTPException(status_code=400, detail="Project id is an immutable field")

        # Delegate only after the lookup session is closed, so two sessions are
        # never held open at the same time.
        if existing_identity is None:
            return self.create_identity(identity=identity, actor=actor)

        identity_update = IdentityUpdate(
            name=identity.name,
            identity_type=identity.identity_type,
            agent_ids=identity.agent_ids,
        )
        return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True)

    @enforce_types
    def update_identity_by_key(
        self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
    ) -> PydanticIdentity:
        """Update an identity's name, type and agents.

        Args:
            identifier_key: Key of the identity to update.
            identity: Fields to change. ``None`` leaves ``name``/``identity_type`` unchanged.
            actor: User performing the update; must belong to the identity's organization.
            replace: If true, the agent list becomes exactly ``identity.agent_ids``
                (``None``/empty clears it). If false, listed agents are appended.

        Returns:
            The updated identity.

        Raises:
            HTTPException: 404 if no identity matches, 403 if it belongs to another organization.
            NoResultFound: If any of ``identity.agent_ids`` does not exist.
        """
        with self.session_maker() as session:
            # NOTE(review): this lookup is by key only (no org/project filter). If
            # identifier_key is not globally unique, another org's row could match
            # and produce a 403. Confirm against the schema.
            existing_identity = self._read_identity_or_none(session, identifier_key=identifier_key)
            if existing_identity is None:
                raise HTTPException(status_code=404, detail="Identity not found")
            if existing_identity.organization_id != actor.organization_id:
                raise HTTPException(status_code=403, detail="Forbidden")

            if identity.name is not None:
                existing_identity.name = identity.name
            if identity.identity_type is not None:
                existing_identity.identity_type = identity.identity_type

            self._process_agent_relationship(
                session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
            )
            existing_identity.update(session, actor=actor)
            return existing_identity.to_pydantic()

    @enforce_types
    def delete_identity_by_key(self, identifier_key: str, actor: PydanticUser) -> None:
        """Delete the identity with ``identifier_key`` and commit.

        Raises:
            HTTPException: 404 if no identity matches, 403 if it belongs to another organization.
        """
        with self.session_maker() as session:
            identity = self._read_identity_or_none(session, identifier_key=identifier_key)
            if identity is None:
                raise HTTPException(status_code=404, detail="Identity not found")
            if identity.organization_id != actor.organization_id:
                raise HTTPException(status_code=403, detail="Forbidden")
            session.delete(identity)
            session.commit()

    @staticmethod
    def _read_identity_or_none(session: Session, **filters) -> Optional[IdentityModel]:
        """Return the identity matching ``filters``, or ``None`` when there is none.

        A ``NoResultFound`` raised by ``IdentityModel.read`` and a ``None`` result are
        both reported as ``None``, so callers need only a single check.
        """
        try:
            return IdentityModel.read(db_session=session, **filters)
        except NoResultFound:
            return None

    def _process_agent_relationship(
        self,
        session: Session,
        identity: IdentityModel,
        agent_ids: Optional[List[str]],
        allow_partial=False,
        replace=True,
    ):
        """Link agents to ``identity`` by ID, modifying its ``agents`` relationship in place.

        Args:
            session: Session used to look up the agents.
            identity: Identity whose ``agents`` relationship is modified.
            agent_ids: Agent IDs to link; duplicates are tolerated. ``None``/empty
                clears the relationship when ``replace`` is true and is a no-op otherwise.
            allow_partial: If false, every requested ID must exist.
            replace: If true, the relationship becomes exactly the found agents.
                Otherwise, found agents not already linked are appended.

        Raises:
            NoResultFound: If ``allow_partial`` is false and some IDs do not exist.
        """
        if not agent_ids:
            if replace:
                identity.agents = []
            return

        found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()

        # Compare as sets. A length comparison misfires on duplicate IDs and would
        # raise with an empty "missing" set.
        missing = set(agent_ids) - {item.id for item in found_items}
        if missing and not allow_partial:
            raise NoResultFound(f"Items not found in agents: {missing}")

        if replace:
            identity.agents = found_items
        else:
            current_relationship = getattr(identity, "agents", [])
            current_ids = {item.id for item in current_relationship}
            current_relationship.extend([item for item in found_items if item.id not in current_ids])