feat: make identities many to many (#1085)

This commit is contained in:
cthomas
2025-02-20 16:33:24 -08:00
committed by GitHub
parent afbb5af30b
commit 31130a6d28
13 changed files with 243 additions and 83 deletions

View File

@@ -0,0 +1,89 @@
"""update identities unique constraint and properties
Revision ID: 549eff097c71
Revises: a3047a624130
Create Date: 2025-02-20 09:53:42.743105
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "549eff097c71"
down_revision: Union[str, None] = "a3047a624130"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Update unique constraint on identities table
op.drop_constraint("unique_identifier_pid_org_id", "identities", type_="unique")
op.create_unique_constraint(
"unique_identifier_without_project",
"identities",
["identifier_key", "project_id", "organization_id"],
postgresql_nulls_not_distinct=True,
)
# Add properties column to identities table
op.add_column("identities", sa.Column("properties", postgresql.JSONB, nullable=False, server_default="[]"))
# Create identities_agents table for many-to-many relationship
op.create_table(
"identities_agents",
sa.Column("identity_id", sa.String(), nullable=False),
sa.Column("agent_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("identity_id", "agent_id"),
)
# Migrate existing relationships
# First, get existing relationships where identity_id is not null
op.execute(
"""
INSERT INTO identities_agents (identity_id, agent_id)
SELECT DISTINCT identity_id, id as agent_id
FROM agents
WHERE identity_id IS NOT NULL
"""
)
# Remove old identity_id column from agents
op.drop_column("agents", "identity_id")
op.drop_column("agents", "identifier_key")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Add back the old columns to agents
op.add_column("agents", sa.Column("identity_id", sa.String(), nullable=True))
op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True))
# Migrate relationships back
op.execute(
"""
UPDATE agents a
SET identity_id = ia.identity_id
FROM identities_agents ia
WHERE a.id = ia.agent_id
"""
)
# Drop the many-to-many table
op.drop_table("identities_agents")
# Drop properties column
op.drop_column("identities", "properties")
# Restore old unique constraint
op.drop_constraint("unique_identifier_without_project", "identities", type_="unique")
op.create_unique_constraint("unique_identifier_pid_org_id", "identities", ["identifier_key", "project_id", "organization_id"])
# ### end Alembic commands ###

View File

@@ -4,6 +4,7 @@ from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.identities_agents import IdentitiesAgents
from letta.orm.identity import Identity
from letta.orm.job import Job
from letta.orm.job_messages import JobMessage

View File

@@ -1,7 +1,7 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Boolean, ForeignKey, Index, String
from sqlalchemy import JSON, Boolean, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
@@ -61,14 +61,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.")
base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.")
# Identity
identity_id: Mapped[Optional[str]] = mapped_column(
String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to."
)
identifier_key: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="The identifier key of the identity the agent belongs to."
)
# Tool rules
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")
@@ -79,7 +71,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
identity: Mapped["Identity"] = relationship("Identity", back_populates="agents")
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
"AgentEnvironmentVariable",
back_populates="agent",
@@ -130,7 +121,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived created by this agent.",
)
identity: Mapped[Optional["Identity"]] = relationship("Identity", back_populates="agents")
identities: Mapped[List["Identity"]] = relationship(
"Identity",
secondary="identities_agents",
lazy="selectin",
back_populates="agents",
passive_deletes=True,
)
def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""
@@ -160,6 +157,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"project_id": self.project_id,
"template_id": self.template_id,
"base_template_id": self.base_template_id,
"identity_ids": [identity.id for identity in self.identities],
"message_buffer_autoclear": self.message_buffer_autoclear,
}

View File

@@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class IdentitiesAgents(Base):
"""Identities may have one or many agents associated with them."""
__tablename__ = "identities_agents"
identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True)
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)

View File

@@ -2,11 +2,13 @@ import uuid
from typing import List, Optional
from sqlalchemy import String, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityProperty
class Identity(SqlalchemyBase, OrganizationMixin):
@@ -14,17 +16,35 @@ class Identity(SqlalchemyBase, OrganizationMixin):
__tablename__ = "identities"
__pydantic_model__ = PydanticIdentity
__table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),)
__table_args__ = (
UniqueConstraint(
"identifier_key",
"project_id",
"organization_id",
name="unique_identifier_without_project",
postgresql_nulls_not_distinct=True,
),
)
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}")
identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.")
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.")
identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.")
project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.")
properties: Mapped[List["IdentityProperty"]] = mapped_column(
JSONB, nullable=False, default=list, doc="List of properties associated with the identity"
)
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity")
agents: Mapped[List["Agent"]] = relationship(
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
)
@property
def agent_ids(self) -> List[str]:
"""Get just the agent IDs without loading the full agent objects"""
return [agent.id for agent in self.agents]
def to_pydantic(self) -> PydanticIdentity:
state = {
@@ -33,7 +53,8 @@ class Identity(SqlalchemyBase, OrganizationMixin):
"name": self.name,
"identity_type": self.identity_type,
"project_id": self.project_id,
"agents": [agent.to_pydantic() for agent in self.agents],
"agent_ids": self.agent_ids,
"organization_id": self.organization_id,
"properties": self.properties,
}
return self.__pydantic_model__(**state)
return PydanticIdentity(**state)

View File

@@ -68,6 +68,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
access_type: AccessType = AccessType.ORGANIZATION,
join_model: Optional[Base] = None,
join_conditions: Optional[Union[Tuple, List]] = None,
identifier_keys: Optional[List[str]] = None,
**kwargs,
) -> List["SqlalchemyBase"]:
"""
@@ -143,6 +144,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Group by primary key and all necessary columns to avoid JSON comparison
query = query.group_by(cls.id)
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
# Apply filtering logic from kwargs
for key, value in kwargs.items():
if "." in key:

View File

@@ -83,9 +83,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
# Identity
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
identity_ids: List[str] = Field([], description="The ids of the identities associated with this agent.")
# An advanced configuration that makes it so this agent does not remember any previous messages
message_buffer_autoclear: bool = Field(
@@ -161,7 +159,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
identity_ids: Optional[List[str]] = Field(None, description="The ids of the identities associated with this agent.")
message_buffer_autoclear: bool = Field(
False,
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
@@ -236,7 +234,7 @@ class UpdateAgent(BaseModel):
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
identity_ids: Optional[List[str]] = Field(None, description="The ids of the identities associated with this agent.")
message_buffer_autoclear: Optional[bool] = Field(
None,
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",

View File

@@ -1,9 +1,8 @@
from enum import Enum
from typing import List, Optional
from typing import List, Optional, Union
from pydantic import Field
from letta.schemas.agent import AgentState
from letta.schemas.letta_base import LettaBase
@@ -17,17 +16,38 @@ class IdentityType(str, Enum):
other = "other"
class IdentityPropertyType(str, Enum):
"""
Enum to represent the type of the identity property.
"""
string = "string"
number = "number"
boolean = "boolean"
json = "json"
class IdentityBase(LettaBase):
__id_prefix__ = "identity"
class IdentityProperty(LettaBase):
"""A property of an identity"""
key: str = Field(..., description="The key of the property")
value: Union[str, int, float, bool, dict] = Field(..., description="The value of the property")
type: IdentityPropertyType = Field(..., description="The type of the property")
class Identity(IdentityBase):
id: str = IdentityBase.generate_id_field()
identifier_key: str = Field(..., description="External, user-generated identifier key of the identity.")
name: str = Field(..., description="The name of the identity.")
identity_type: IdentityType = Field(..., description="The type of the identity.")
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
agents: List[AgentState] = Field(..., description="The agents associated with the identity.")
agent_ids: List[str] = Field(..., description="The IDs of the agents associated with the identity.")
organization_id: Optional[str] = Field(None, description="The organization id of the user")
properties: List[IdentityProperty] = Field(default_factory=list, description="List of properties associated with the identity")
class IdentityCreate(LettaBase):
@@ -36,9 +56,12 @@ class IdentityCreate(LettaBase):
identity_type: IdentityType = Field(..., description="The type of the identity.")
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
class IdentityUpdate(LettaBase):
identifier_key: Optional[str] = Field(None, description="External, user-generated identifier key of the identity.")
name: Optional[str] = Field(None, description="The name of the identity.")
identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.")
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")

View File

@@ -1,3 +1,4 @@
import traceback
from datetime import datetime
from typing import Annotated, List, Optional
@@ -51,7 +52,7 @@ def list_agents(
project_id: Optional[str] = Query(None, description="Search agents by project id"),
template_id: Optional[str] = Query(None, description="Search agents by template id"),
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
identifier_key: Optional[str] = Query(None, description="Search agents by identifier key"),
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
):
"""
List all agents associated with a given user.
@@ -67,7 +68,6 @@ def list_agents(
"project_id": project_id,
"template_id": template_id,
"base_template_id": base_template_id,
"identifier_key": identifier_key,
}.items()
if value is not None
}
@@ -81,6 +81,7 @@ def list_agents(
query_text=query_text,
tags=tags,
match_all_tags=match_all_tags,
identifier_keys=identifier_keys,
**kwargs,
)
return agents
@@ -119,8 +120,12 @@ def create_agent(
"""
Create a new agent with the specified configuration.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.create_agent(agent, actor=actor)
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.create_agent(agent, actor=actor)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@router.patch("/{agent_id}", response_model=AgentState, operation_id="modify_agent")

View File

@@ -16,6 +16,7 @@ router = APIRouter(prefix="/identities", tags=["identities"])
def list_identities(
name: Optional[str] = Query(None),
project_id: Optional[str] = Query(None),
identifier_key: Optional[str] = Query(None),
identity_type: Optional[IdentityType] = Query(None),
before: Optional[str] = Query(None),
after: Optional[str] = Query(None),
@@ -30,7 +31,14 @@ def list_identities(
actor = server.user_manager.get_user_or_default(user_id=user_id)
identities = server.identity_manager.list_identities(
name=name, project_id=project_id, identity_type=identity_type, before=before, after=after, limit=limit, actor=actor
name=name,
project_id=project_id,
identifier_key=identifier_key,
identity_type=identity_type,
before=before,
after=after,
limit=limit,
actor=actor,
)
except HTTPException:
raise
@@ -39,13 +47,13 @@ def list_identities(
return identities
@router.get("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="get_identity_from_identifier_key")
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
def retrieve_identity(
identifier_key: str,
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
):
try:
return server.identity_manager.get_identity_from_identifier_key(identifier_key=identifier_key)
return server.identity_manager.get_identity(identity_id=identity_id)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -82,25 +90,25 @@ def upsert_identity(
raise HTTPException(status_code=500, detail=f"{e}")
@router.patch("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="update_identity")
@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity")
def modify_identity(
identifier_key: str,
identity_id: str,
identity: IdentityUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor)
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.delete("/{identifier_key}", tags=["identities"], operation_id="delete_identity")
@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity")
def delete_identity(
identifier_key: str,
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@@ -108,4 +116,4 @@ def delete_identity(
Delete an identity by its identifier key
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.identity_manager.delete_identity_by_key(identifier_key=identifier_key, actor=actor)
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)

View File

@@ -11,6 +11,7 @@ from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage, AgentsTags
from letta.orm import Block as BlockModel
from letta.orm import Identity as IdentityModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
@@ -34,7 +35,6 @@ from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import SerializedAgentSchema
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_identity,
_process_relationship,
_process_tags,
check_supports_structured_output,
@@ -138,6 +138,7 @@ class AgentManager:
tool_ids=tool_ids,
source_ids=agent_create.source_ids or [],
tags=agent_create.tags or [],
identity_ids=agent_create.identity_ids or [],
description=agent_create.description,
metadata=agent_create.metadata,
tool_rules=tool_rules,
@@ -145,7 +146,6 @@ class AgentManager:
project_id=agent_create.project_id,
template_id=agent_create.template_id,
base_template_id=agent_create.base_template_id,
identifier_key=agent_create.identifier_key,
message_buffer_autoclear=agent_create.message_buffer_autoclear,
)
@@ -203,13 +203,13 @@ class AgentManager:
tool_ids: List[str],
source_ids: List[str],
tags: List[str],
identity_ids: List[str],
description: Optional[str] = None,
metadata: Optional[Dict] = None,
tool_rules: Optional[List[PydanticToolRule]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
base_template_id: Optional[str] = None,
identifier_key: Optional[str] = None,
message_buffer_autoclear: bool = False,
) -> PydanticAgentState:
"""Create a new agent."""
@@ -237,9 +237,7 @@ class AgentManager:
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
_process_tags(new_agent, tags, replace=True)
if identifier_key is not None:
identity = self.identity_manager.get_identity_from_identifier_key(identifier_key)
_process_identity(new_agent, identifier_key, identity)
_process_relationship(session, new_agent, "identities", IdentityModel, identity_ids, replace=True)
new_agent.create(session, actor=actor)
@@ -313,9 +311,8 @@ class AgentManager:
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
if agent_update.tags is not None:
_process_tags(agent, agent_update.tags, replace=True)
if agent_update.identifier_key is not None:
identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key)
_process_identity(agent, agent_update.identifier_key, identity)
if agent_update.identity_ids is not None:
_process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True)
# Commit and refresh the agent
agent.update(session, actor=actor)
@@ -333,6 +330,7 @@ class AgentManager:
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
query_text: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
**kwargs,
) -> List[PydanticAgentState]:
"""
@@ -348,6 +346,7 @@ class AgentManager:
match_all_tags=match_all_tags,
organization_id=actor.organization_id if actor else None,
query_text=query_text,
identifier_keys=identifier_keys,
**kwargs,
)

View File

@@ -11,7 +11,6 @@ from letta.orm.errors import NoResultFound
from letta.prompts import gpt_system
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.identity import Identity
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate, TextContent
from letta.schemas.tool_rule import ToolRule
@@ -85,20 +84,6 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True):
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity):
"""
Handles identity for an agent.
Args:
agent: The AgentModel instance.
identifier_key: The identifier key of the identity to set or update.
identity: The Identity object to set or update.
"""
agent.identifier_key = identifier_key
agent.identity = identity
agent.identity_id = identity.id
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
if system is None:
# TODO: don't hardcode

View File

@@ -1,6 +1,7 @@
from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from letta.orm.agent import Agent as AgentModel
@@ -23,6 +24,7 @@ class IdentityManager:
self,
name: Optional[str] = None,
project_id: Optional[str] = None,
identifier_key: Optional[str] = None,
identity_type: Optional[IdentityType] = None,
before: Optional[str] = None,
after: Optional[str] = None,
@@ -33,6 +35,8 @@ class IdentityManager:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
if identifier_key:
filters["identifier_key"] = identifier_key
if identity_type:
filters["identity_type"] = identity_type
identities = IdentityModel.list(
@@ -46,9 +50,9 @@ class IdentityManager:
return [identity.to_pydantic() for identity in identities]
@enforce_types
def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()
@enforce_types
@@ -68,44 +72,56 @@ class IdentityManager:
identifier_key=identity.identifier_key,
project_id=identity.project_id,
organization_id=actor.organization_id,
actor=actor,
)
if existing_identity is None:
return self.create_identity(identity=identity, actor=actor)
else:
if existing_identity.identifier_key != identity.identifier_key:
raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
if existing_identity.project_id != identity.project_id:
raise HTTPException(status_code=400, detail="Project id is an immutable field")
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True)
return self._update_identity(
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
)
@enforce_types
def update_identity_by_key(
self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
) -> PydanticIdentity:
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
with self.session_maker() as session:
try:
existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail="Identity not found")
if existing_identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
existing_identity.name = identity.name if identity.name is not None else existing_identity.name
existing_identity.identity_type = (
identity.identity_type if identity.identity_type is not None else existing_identity.identity_type
return self._update_identity(
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
)
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
)
existing_identity.update(session, actor=actor)
return existing_identity.to_pydantic()
def _update_identity(
self,
session: Session,
existing_identity: IdentityModel,
identity: IdentityUpdate,
actor: PydanticUser,
replace: bool = False,
) -> PydanticIdentity:
if identity.identifier_key is not None:
existing_identity.identifier_key = identity.identifier_key
if identity.name is not None:
existing_identity.name = identity.name
if identity.identity_type is not None:
existing_identity.identity_type = identity.identity_type
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
)
existing_identity.update(session, actor=actor)
return existing_identity.to_pydantic()
@enforce_types
def delete_identity_by_key(self, identifier_key: str, actor: PydanticUser) -> None:
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
identity = IdentityModel.read(db_session=session, identifier=identity_id)
if identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
if identity.organization_id != actor.organization_id: