diff --git a/alembic/versions/a113caac453e_add_identities_table.py b/alembic/versions/a113caac453e_add_identities_table.py new file mode 100644 index 00000000..7c4d4140 --- /dev/null +++ b/alembic/versions/a113caac453e_add_identities_table.py @@ -0,0 +1,66 @@ +"""add identities table + +Revision ID: a113caac453e +Revises: 7980d239ea08 +Create Date: 2025-02-14 09:58:18.227122 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a113caac453e" +down_revision: Union[str, None] = "7980d239ea08" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create identities table + op.create_table( + "identities", + sa.Column("id", sa.String(), nullable=False), + sa.Column("identifier_key", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("identity_type", sa.String(), nullable=False), + sa.Column("project_id", sa.String(), nullable=True), + # From OrganizationMixin + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + # Foreign key to organizations + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + # Composite unique constraint + sa.UniqueConstraint( + "identifier_key", + "project_id", + "organization_id", + name="unique_identifier_pid_org_id", + ), + sa.PrimaryKeyConstraint("id"), + ) + + # Add identity_id column to agents table + op.add_column("agents", sa.Column("identity_id", sa.String(), nullable=True)) + + # Add foreign key constraint + op.create_foreign_key("fk_agents_identity_id", "agents", "identities", ["identity_id"], ["id"], ondelete="CASCADE") + + +def downgrade() -> None: + # First remove the foreign key constraint and column from agents + op.drop_constraint("fk_agents_identity_id", "agents", type_="foreignkey") + op.drop_column("agents", "identity_id") + + # Then drop the table + op.drop_table("identities") diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 5898dd80..28feb237 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -4,6 +4,7 @@ from letta.orm.base import Base from letta.orm.block import Block from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata +from letta.orm.identity import Identity from letta.orm.job import Job from letta.orm.job_messages import JobMessage from letta.orm.message import Message diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 07b3917b..3555dd7f 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -1,7 +1,7 @@ import uuid from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, Boolean, Index, String +from sqlalchemy import JSON, Boolean, ForeignKey, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block @@ -19,6 +19,7 @@ from letta.schemas.tool_rule import TerminalToolRule, ToolRule if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags + from letta.orm.identity import Identity from letta.orm.organization import Organization from letta.orm.source import Source from letta.orm.tool import Tool @@ -59,6 +60,11 @@ class Agent(SqlalchemyBase, OrganizationMixin): template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.") base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.") + # Identity + identity_id: Mapped[Optional[str]] = mapped_column( + String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to." + ) + # Tool rules tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.") @@ -119,6 +125,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship doc="All passages derived created by this agent.", ) + identity: Mapped[Optional["Identity"]] = relationship("Identity", back_populates="agents") def to_pydantic(self) -> PydanticAgentState: """converts to the basic pydantic model counterpart""" diff --git a/letta/orm/identity.py b/letta/orm/identity.py new file mode 100644 index 00000000..52d7b8ab --- /dev/null +++ b/letta/orm/identity.py @@ -0,0 +1,25 @@ +from typing import List, Optional + +from sqlalchemy import UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.identity import Identity as PydanticIdentity + + +class Identity(SqlalchemyBase, OrganizationMixin): + """Identity ORM class""" + + __tablename__ = "identities" + __pydantic_model__ = PydanticIdentity + __table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),) + + identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.") + name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.") + identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.") + project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="identities") + agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="identity") diff --git a/letta/orm/organization.py b/letta/orm/organization.py index cef5adbd..fc8dcfc7 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from letta.orm.agent import Agent from letta.orm.file import FileMetadata + from letta.orm.identity import Identity from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable from letta.orm.tool import Tool @@ -47,6 +48,7 @@ class Organization(SqlalchemyBase): ) agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan") providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan") + identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan") @property def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: diff --git a/letta/schemas/identity.py b/letta/schemas/identity.py new file mode 100644 index 00000000..00db9e87 --- /dev/null +++ b/letta/schemas/identity.py @@ -0,0 +1,30 @@ +from enum import Enum +from typing import List, Optional + +from pydantic import Field + +from letta.schemas.agent import AgentState +from letta.schemas.letta_base import LettaBase + + +class IdentityType(str, Enum): + """ + Enum to represent the type of the identity. + """ + + org = "org" + user = "user" + other = "other" + + +class IdentityBase(LettaBase): + __id_prefix__ = "identity" + + +class Identity(IdentityBase): + id: str = Field(..., description="The internal id of the identity.") + identifier_key: str = Field(..., description="External, user-generated identifier key of the identity.") + name: str = Field(..., description="The name of the identity.") + identity_type: IdentityType = Field(..., description="The type of the identity.") + project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.") + agents: List[AgentState] = Field(..., description="The ids of the agents associated with the identity.")