diff --git a/alembic/versions/549eff097c71_update_identities_unique_constraint_and_.py b/alembic/versions/549eff097c71_update_identities_unique_constraint_and_.py new file mode 100644 index 00000000..97a72543 --- /dev/null +++ b/alembic/versions/549eff097c71_update_identities_unique_constraint_and_.py @@ -0,0 +1,89 @@ +"""update identities unique constraint and properties + +Revision ID: 549eff097c71 +Revises: a3047a624130 +Create Date: 2025-02-20 09:53:42.743105 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "549eff097c71" +down_revision: Union[str, None] = "a3047a624130" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Update unique constraint on identities table + op.drop_constraint("unique_identifier_pid_org_id", "identities", type_="unique") + op.create_unique_constraint( + "unique_identifier_without_project", + "identities", + ["identifier_key", "project_id", "organization_id"], + postgresql_nulls_not_distinct=True, + ) + + # Add properties column to identities table + op.add_column("identities", sa.Column("properties", postgresql.JSONB, nullable=False, server_default="[]")) + + # Create identities_agents table for many-to-many relationship + op.create_table( + "identities_agents", + sa.Column("identity_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("identity_id", "agent_id"), + ) + + # Migrate existing relationships + # First, get existing relationships where identity_id is not null + op.execute( + """ + INSERT INTO identities_agents (identity_id, agent_id) + SELECT DISTINCT identity_id, id as agent_id + FROM agents + WHERE identity_id IS NOT NULL + """ + ) + + # Remove old identity_id column from agents + op.drop_column("agents", "identity_id") + op.drop_column("agents", "identifier_key") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Add back the old columns to agents + op.add_column("agents", sa.Column("identity_id", sa.String(), nullable=True)) + op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True)) + + # Migrate relationships back + op.execute( + """ + UPDATE agents a + SET identity_id = ia.identity_id + FROM identities_agents ia + WHERE a.id = ia.agent_id + """ + ) + + # Drop the many-to-many table + op.drop_table("identities_agents") + + # Drop properties column + op.drop_column("identities", "properties") + + # Restore old unique constraint + op.drop_constraint("unique_identifier_without_project", "identities", type_="unique") + op.create_unique_constraint("unique_identifier_pid_org_id", "identities", ["identifier_key", "project_id", "organization_id"]) + # ### end Alembic commands ### diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 28feb237..10c25253 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -4,6 +4,7 @@ from letta.orm.base import Base from letta.orm.block import Block from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata +from letta.orm.identities_agents import IdentitiesAgents from letta.orm.identity import Identity from letta.orm.job import Job from letta.orm.job_messages import JobMessage diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 5efb3366..59f7f1ff 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -1,7 +1,7 @@ import uuid from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, Boolean, ForeignKey, Index, String +from sqlalchemy import JSON, Boolean, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block @@ -61,14 +61,6 @@ class Agent(SqlalchemyBase, OrganizationMixin): template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.") base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.") - # Identity - identity_id: Mapped[Optional[str]] = mapped_column( - String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to." - ) - identifier_key: Mapped[Optional[str]] = mapped_column( - String, nullable=True, doc="The identifier key of the identity the agent belongs to." - ) - # Tool rules tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.") @@ -79,7 +71,6 @@ class Agent(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") - identity: Mapped["Identity"] = relationship("Identity", back_populates="agents") tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( "AgentEnvironmentVariable", back_populates="agent", @@ -130,7 +121,13 @@ class Agent(SqlalchemyBase, OrganizationMixin): viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship doc="All passages derived created by this agent.", ) - identity: Mapped[Optional["Identity"]] = relationship("Identity", back_populates="agents") + identities: Mapped[List["Identity"]] = relationship( + "Identity", + secondary="identities_agents", + lazy="selectin", + back_populates="agents", + passive_deletes=True, + ) def to_pydantic(self) -> PydanticAgentState: """converts to the basic pydantic model counterpart""" @@ -160,6 +157,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): "project_id": self.project_id, "template_id": self.template_id, "base_template_id": self.base_template_id, + "identity_ids": [identity.id for identity in self.identities], "message_buffer_autoclear": self.message_buffer_autoclear, } diff --git a/letta/orm/identities_agents.py b/letta/orm/identities_agents.py new file mode 100644 index 00000000..a8958691 --- /dev/null +++ b/letta/orm/identities_agents.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.base import Base + + +class IdentitiesAgents(Base): + """Identities may have one or many agents associated with them.""" + + __tablename__ = "identities_agents" + + identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True) + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True) diff --git a/letta/orm/identity.py b/letta/orm/identity.py index 4a7cfefd..1a8058e5 100644 --- a/letta/orm/identity.py +++ b/letta/orm/identity.py @@ -2,11 +2,13 @@ import uuid from typing import List, Optional from sqlalchemy import String, UniqueConstraint +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.identity import Identity as PydanticIdentity +from letta.schemas.identity import IdentityProperty class Identity(SqlalchemyBase, OrganizationMixin): @@ -14,17 +16,35 @@ class Identity(SqlalchemyBase, OrganizationMixin): __tablename__ = "identities" __pydantic_model__ = PydanticIdentity - __table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),) + __table_args__ = ( + UniqueConstraint( + "identifier_key", + "project_id", + "organization_id", + name="unique_identifier_without_project", + postgresql_nulls_not_distinct=True, + ), + ) id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}") identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.") name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.") identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.") project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.") + properties: Mapped[List["IdentityProperty"]] = mapped_column( + JSONB, nullable=False, default=list, doc="List of properties associated with the identity" + ) # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="identities") - agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity") + agents: Mapped[List["Agent"]] = relationship( + "Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities" + ) + + @property + def agent_ids(self) -> List[str]: + """Get just the agent IDs without loading the full agent objects""" + return [agent.id for agent in self.agents] def to_pydantic(self) -> PydanticIdentity: state = { @@ -33,7 +53,8 @@ class Identity(SqlalchemyBase, OrganizationMixin): "name": self.name, "identity_type": self.identity_type, "project_id": self.project_id, - "agents": [agent.to_pydantic() for agent in self.agents], + "agent_ids": self.agent_ids, + "organization_id": self.organization_id, + "properties": self.properties, } - - return self.__pydantic_model__(**state) + return PydanticIdentity(**state) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 8cdd686a..0bca5dd3 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -68,6 +68,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): access_type: AccessType = AccessType.ORGANIZATION, join_model: Optional[Base] = None, join_conditions: Optional[Union[Tuple, List]] = None, + identifier_keys: Optional[List[str]] = None, **kwargs, ) -> List["SqlalchemyBase"]: """ @@ -143,6 +144,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # Group by primary key and all necessary columns to avoid JSON comparison query = query.group_by(cls.id) + if identifier_keys and hasattr(cls, "identities"): + query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) + # Apply filtering logic from kwargs for key, value in kwargs.items(): if "." in key: diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index da06e6b5..50cbe030 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -83,9 +83,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True): project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") - - # Identity - identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.") + identity_ids: List[str] = Field([], description="The ids of the identities associated with this agent.") # An advanced configuration that makes it so this agent does not remember any previous messages message_buffer_autoclear: bool = Field( @@ -161,7 +159,7 @@ class CreateAgent(BaseModel, validate_assignment=True): # project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") - identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.") + identity_ids: Optional[List[str]] = Field(None, description="The ids of the identities associated with this agent.") message_buffer_autoclear: bool = Field( False, description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", @@ -236,7 +234,7 @@ class UpdateAgent(BaseModel): project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") - identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.") + identity_ids: Optional[List[str]] = Field(None, description="The ids of the identities associated with this agent.") message_buffer_autoclear: Optional[bool] = Field( None, description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", diff --git a/letta/schemas/identity.py b/letta/schemas/identity.py index 204826a9..5b44fa67 100644 --- a/letta/schemas/identity.py +++ b/letta/schemas/identity.py @@ -1,9 +1,8 @@ from enum import Enum -from typing import List, Optional +from typing import List, Optional, Union from pydantic import Field -from letta.schemas.agent import AgentState from letta.schemas.letta_base import LettaBase @@ -17,17 +16,38 @@ class IdentityType(str, Enum): other = "other" +class IdentityPropertyType(str, Enum): + """ + Enum to represent the type of the identity property. + """ + + string = "string" + number = "number" + boolean = "boolean" + json = "json" + + class IdentityBase(LettaBase): __id_prefix__ = "identity" +class IdentityProperty(LettaBase): + """A property of an identity""" + + key: str = Field(..., description="The key of the property") + value: Union[str, int, float, bool, dict] = Field(..., description="The value of the property") + type: IdentityPropertyType = Field(..., description="The type of the property") + + class Identity(IdentityBase): id: str = IdentityBase.generate_id_field() identifier_key: str = Field(..., description="External, user-generated identifier key of the identity.") name: str = Field(..., description="The name of the identity.") identity_type: IdentityType = Field(..., description="The type of the identity.") project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.") - agents: List[AgentState] = Field(..., description="The agents associated with the identity.") + agent_ids: List[str] = Field(..., description="The IDs of the agents associated with the identity.") + organization_id: Optional[str] = Field(None, description="The organization id of the user") + properties: List[IdentityProperty] = Field(default_factory=list, description="List of properties associated with the identity") class IdentityCreate(LettaBase): @@ -36,9 +56,12 @@ class IdentityCreate(LettaBase): identity_type: IdentityType = Field(..., description="The type of the identity.") project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.") agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.") + properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.") class IdentityUpdate(LettaBase): + identifier_key: Optional[str] = Field(None, description="External, user-generated identifier key of the identity.") name: Optional[str] = Field(None, description="The name of the identity.") identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.") agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.") + properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index ed8a8178..4c9d93f6 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1,3 +1,4 @@ +import traceback from datetime import datetime from typing import Annotated, List, Optional @@ -51,7 +52,7 @@ def list_agents( project_id: Optional[str] = Query(None, description="Search agents by project id"), template_id: Optional[str] = Query(None, description="Search agents by template id"), base_template_id: Optional[str] = Query(None, description="Search agents by base template id"), - identifier_key: Optional[str] = Query(None, description="Search agents by identifier key"), + identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), ): """ List all agents associated with a given user. @@ -67,7 +68,6 @@ def list_agents( "project_id": project_id, "template_id": template_id, "base_template_id": base_template_id, - "identifier_key": identifier_key, }.items() if value is not None } @@ -81,6 +81,7 @@ def list_agents( query_text=query_text, tags=tags, match_all_tags=match_all_tags, + identifier_keys=identifier_keys, **kwargs, ) return agents @@ -119,8 +120,12 @@ def create_agent( """ Create a new agent with the specified configuration. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.create_agent(agent, actor=actor) + try: + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.create_agent(agent, actor=actor) + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=500, detail=str(e)) @router.patch("/{agent_id}", response_model=AgentState, operation_id="modify_agent") diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index 24a136d4..5e6b6853 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -16,6 +16,7 @@ router = APIRouter(prefix="/identities", tags=["identities"]) def list_identities( name: Optional[str] = Query(None), project_id: Optional[str] = Query(None), + identifier_key: Optional[str] = Query(None), identity_type: Optional[IdentityType] = Query(None), before: Optional[str] = Query(None), after: Optional[str] = Query(None), @@ -30,7 +31,14 @@ def list_identities( actor = server.user_manager.get_user_or_default(user_id=user_id) identities = server.identity_manager.list_identities( - name=name, project_id=project_id, identity_type=identity_type, before=before, after=after, limit=limit, actor=actor + name=name, + project_id=project_id, + identifier_key=identifier_key, + identity_type=identity_type, + before=before, + after=after, + limit=limit, + actor=actor, ) except HTTPException: raise @@ -39,13 +47,13 @@ def list_identities( return identities -@router.get("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="get_identity_from_identifier_key") +@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity") def retrieve_identity( - identifier_key: str, + identity_id: str, server: "SyncServer" = Depends(get_letta_server), ): try: - return server.identity_manager.get_identity_from_identifier_key(identifier_key=identifier_key) + return server.identity_manager.get_identity(identity_id=identity_id) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -82,25 +90,25 @@ def upsert_identity( raise HTTPException(status_code=500, detail=f"{e}") -@router.patch("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="update_identity") +@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity") def modify_identity( - identifier_key: str, + identity_id: str, identity: IdentityUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor) + return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") -@router.delete("/{identifier_key}", tags=["identities"], operation_id="delete_identity") +@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity") def delete_identity( - identifier_key: str, + identity_id: str, server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -108,4 +116,4 @@ def delete_identity( Delete an identity by its identifier key """ actor = server.user_manager.get_user_or_default(user_id=user_id) - server.identity_manager.delete_identity_by_key(identifier_key=identifier_key, actor=actor) + server.identity_manager.delete_identity(identity_id=identity_id, actor=actor) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d921540d..fb450d3d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -11,6 +11,7 @@ from letta.log import get_logger from letta.orm import Agent as AgentModel from letta.orm import AgentPassage, AgentsTags from letta.orm import Block as BlockModel +from letta.orm import Identity as IdentityModel from letta.orm import Source as SourceModel from letta.orm import SourcePassage, SourcesAgents from letta.orm import Tool as ToolModel @@ -34,7 +35,6 @@ from letta.schemas.user import User as PydanticUser from letta.serialize_schemas import SerializedAgentSchema from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( - _process_identity, _process_relationship, _process_tags, check_supports_structured_output, @@ -138,6 +138,7 @@ class AgentManager: tool_ids=tool_ids, source_ids=agent_create.source_ids or [], tags=agent_create.tags or [], + identity_ids=agent_create.identity_ids or [], description=agent_create.description, metadata=agent_create.metadata, tool_rules=tool_rules, @@ -145,7 +146,6 @@ class AgentManager: project_id=agent_create.project_id, template_id=agent_create.template_id, base_template_id=agent_create.base_template_id, - identifier_key=agent_create.identifier_key, message_buffer_autoclear=agent_create.message_buffer_autoclear, ) @@ -203,13 +203,13 @@ class AgentManager: tool_ids: List[str], source_ids: List[str], tags: List[str], + identity_ids: List[str], description: Optional[str] = None, metadata: Optional[Dict] = None, tool_rules: Optional[List[PydanticToolRule]] = None, project_id: Optional[str] = None, template_id: Optional[str] = None, base_template_id: Optional[str] = None, - identifier_key: Optional[str] = None, message_buffer_autoclear: bool = False, ) -> PydanticAgentState: """Create a new agent.""" @@ -237,9 +237,7 @@ class AgentManager: _process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True) _process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True) _process_tags(new_agent, tags, replace=True) - if identifier_key is not None: - identity = self.identity_manager.get_identity_from_identifier_key(identifier_key) - _process_identity(new_agent, identifier_key, identity) + _process_relationship(session, new_agent, "identities", IdentityModel, identity_ids, replace=True) new_agent.create(session, actor=actor) @@ -313,9 +311,8 @@ class AgentManager: _process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True) if agent_update.tags is not None: _process_tags(agent, agent_update.tags, replace=True) - if agent_update.identifier_key is not None: - identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key) - _process_identity(agent, agent_update.identifier_key, identity) + if agent_update.identity_ids is not None: + _process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True) # Commit and refresh the agent agent.update(session, actor=actor) @@ -333,6 +330,7 @@ class AgentManager: tags: Optional[List[str]] = None, match_all_tags: bool = False, query_text: Optional[str] = None, + identifier_keys: Optional[List[str]] = None, **kwargs, ) -> List[PydanticAgentState]: """ @@ -348,6 +346,7 @@ class AgentManager: match_all_tags=match_all_tags, organization_id=actor.organization_id if actor else None, query_text=query_text, + identifier_keys=identifier_keys, **kwargs, ) diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 098ffc52..8d99449c 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -11,7 +11,6 @@ from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.schemas.agent import AgentState, AgentType from letta.schemas.enums import MessageRole -from letta.schemas.identity import Identity from letta.schemas.memory import Memory from letta.schemas.message import Message, MessageCreate, TextContent from letta.schemas.tool_rule import ToolRule @@ -85,20 +84,6 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True): agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags]) -def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity): - """ - Handles identity for an agent. - - Args: - agent: The AgentModel instance. - identifier_key: The identifier key of the identity to set or update. - identity: The Identity object to set or update. - """ - agent.identifier_key = identifier_key - agent.identity = identity - agent.identity_id = identity.id - - def derive_system_message(agent_type: AgentType, system: Optional[str] = None): if system is None: # TODO: don't hardcode diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 3973fa1b..53058960 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -1,6 +1,7 @@ from typing import List, Optional from fastapi import HTTPException +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session from letta.orm.agent import Agent as AgentModel @@ -23,6 +24,7 @@ class IdentityManager: self, name: Optional[str] = None, project_id: Optional[str] = None, + identifier_key: Optional[str] = None, identity_type: Optional[IdentityType] = None, before: Optional[str] = None, after: Optional[str] = None, @@ -33,6 +35,8 @@ class IdentityManager: filters = {"organization_id": actor.organization_id} if project_id: filters["project_id"] = project_id + if identifier_key: + filters["identifier_key"] = identifier_key if identity_type: filters["identity_type"] = identity_type identities = IdentityModel.list( @@ -46,9 +50,9 @@ class IdentityManager: return [identity.to_pydantic() for identity in identities] @enforce_types - def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity: + def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: with self.session_maker() as session: - identity = IdentityModel.read(db_session=session, identifier_key=identifier_key) + identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) return identity.to_pydantic() @enforce_types @@ -68,44 +72,56 @@ class IdentityManager: identifier_key=identity.identifier_key, project_id=identity.project_id, organization_id=actor.organization_id, + actor=actor, ) if existing_identity is None: return self.create_identity(identity=identity, actor=actor) else: - if existing_identity.identifier_key != identity.identifier_key: - raise HTTPException(status_code=400, detail="Identifier key is an immutable field") - if existing_identity.project_id != identity.project_id: - raise HTTPException(status_code=400, detail="Project id is an immutable field") identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids) - return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True) + return self._update_identity( + session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True + ) @enforce_types - def update_identity_by_key( - self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False - ) -> PydanticIdentity: + def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity: with self.session_maker() as session: try: - existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key) + existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) except NoResultFound: raise HTTPException(status_code=404, detail="Identity not found") if existing_identity.organization_id != actor.organization_id: raise HTTPException(status_code=403, detail="Forbidden") - existing_identity.name = identity.name if identity.name is not None else existing_identity.name - existing_identity.identity_type = ( - identity.identity_type if identity.identity_type is not None else existing_identity.identity_type + return self._update_identity( + session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace ) - self._process_agent_relationship( - session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace - ) - existing_identity.update(session, actor=actor) - return existing_identity.to_pydantic() + + def _update_identity( + self, + session: Session, + existing_identity: IdentityModel, + identity: IdentityUpdate, + actor: PydanticUser, + replace: bool = False, + ) -> PydanticIdentity: + if identity.identifier_key is not None: + existing_identity.identifier_key = identity.identifier_key + if identity.name is not None: + existing_identity.name = identity.name + if identity.identity_type is not None: + existing_identity.identity_type = identity.identity_type + + self._process_agent_relationship( + session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace + ) + existing_identity.update(session, actor=actor) + return existing_identity.to_pydantic() @enforce_types - def delete_identity_by_key(self, identifier_key: str, actor: PydanticUser) -> None: + def delete_identity(self, identity_id: str, actor: PydanticUser) -> None: with self.session_maker() as session: - identity = IdentityModel.read(db_session=session, identifier_key=identifier_key) + identity = IdentityModel.read(db_session=session, identifier=identity_id) if identity is None: raise HTTPException(status_code=404, detail="Identity not found") if identity.organization_id != actor.organization_id: