feat: make identities many to many (#1085)

This commit is contained in:
cthomas
2025-02-20 16:33:24 -08:00
committed by GitHub
parent afbb5af30b
commit 31130a6d28
13 changed files with 243 additions and 83 deletions

View File

@@ -4,6 +4,7 @@ from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.identities_agents import IdentitiesAgents
from letta.orm.identity import Identity
from letta.orm.job import Job
from letta.orm.job_messages import JobMessage

View File

@@ -1,7 +1,7 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Boolean, ForeignKey, Index, String
from sqlalchemy import JSON, Boolean, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
@@ -61,14 +61,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.")
base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.")
# Identity
identity_id: Mapped[Optional[str]] = mapped_column(
String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to."
)
identifier_key: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="The identifier key of the identity the agent belongs to."
)
# Tool rules
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")
@@ -79,7 +71,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
identity: Mapped["Identity"] = relationship("Identity", back_populates="agents")
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
"AgentEnvironmentVariable",
back_populates="agent",
@@ -130,7 +121,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived created by this agent.",
)
identity: Mapped[Optional["Identity"]] = relationship("Identity", back_populates="agents")
identities: Mapped[List["Identity"]] = relationship(
"Identity",
secondary="identities_agents",
lazy="selectin",
back_populates="agents",
passive_deletes=True,
)
def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""
@@ -160,6 +157,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"project_id": self.project_id,
"template_id": self.template_id,
"base_template_id": self.base_template_id,
"identity_ids": [identity.id for identity in self.identities],
"message_buffer_autoclear": self.message_buffer_autoclear,
}

View File

@@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class IdentitiesAgents(Base):
"""Identities may have one or many agents associated with them."""
__tablename__ = "identities_agents"
identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True)
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)

View File

@@ -2,11 +2,13 @@ import uuid
from typing import List, Optional
from sqlalchemy import String, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityProperty
class Identity(SqlalchemyBase, OrganizationMixin):
@@ -14,17 +16,35 @@ class Identity(SqlalchemyBase, OrganizationMixin):
__tablename__ = "identities"
__pydantic_model__ = PydanticIdentity
__table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),)
__table_args__ = (
UniqueConstraint(
"identifier_key",
"project_id",
"organization_id",
name="unique_identifier_without_project",
postgresql_nulls_not_distinct=True,
),
)
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}")
identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.")
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.")
identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.")
project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.")
properties: Mapped[List["IdentityProperty"]] = mapped_column(
JSONB, nullable=False, default=list, doc="List of properties associated with the identity"
)
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity")
agents: Mapped[List["Agent"]] = relationship(
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
)
@property
def agent_ids(self) -> List[str]:
"""Get just the agent IDs without loading the full agent objects"""
return [agent.id for agent in self.agents]
def to_pydantic(self) -> PydanticIdentity:
state = {
@@ -33,7 +53,8 @@ class Identity(SqlalchemyBase, OrganizationMixin):
"name": self.name,
"identity_type": self.identity_type,
"project_id": self.project_id,
"agents": [agent.to_pydantic() for agent in self.agents],
"agent_ids": self.agent_ids,
"organization_id": self.organization_id,
"properties": self.properties,
}
return self.__pydantic_model__(**state)
return PydanticIdentity(**state)

View File

@@ -68,6 +68,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
access_type: AccessType = AccessType.ORGANIZATION,
join_model: Optional[Base] = None,
join_conditions: Optional[Union[Tuple, List]] = None,
identifier_keys: Optional[List[str]] = None,
**kwargs,
) -> List["SqlalchemyBase"]:
"""
@@ -143,6 +144,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Group by primary key and all necessary columns to avoid JSON comparison
query = query.group_by(cls.id)
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
# Apply filtering logic from kwargs
for key, value in kwargs.items():
if "." in key: