feat: make identities many to many (#1085)
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
"""update identities unique constraint and properties
|
||||
|
||||
Revision ID: 549eff097c71
|
||||
Revises: a3047a624130
|
||||
Create Date: 2025-02-20 09:53:42.743105
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "549eff097c71"
|
||||
down_revision: Union[str, None] = "a3047a624130"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Update unique constraint on identities table
|
||||
op.drop_constraint("unique_identifier_pid_org_id", "identities", type_="unique")
|
||||
op.create_unique_constraint(
|
||||
"unique_identifier_without_project",
|
||||
"identities",
|
||||
["identifier_key", "project_id", "organization_id"],
|
||||
postgresql_nulls_not_distinct=True,
|
||||
)
|
||||
|
||||
# Add properties column to identities table
|
||||
op.add_column("identities", sa.Column("properties", postgresql.JSONB, nullable=False, server_default="[]"))
|
||||
|
||||
# Create identities_agents table for many-to-many relationship
|
||||
op.create_table(
|
||||
"identities_agents",
|
||||
sa.Column("identity_id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["identity_id"], ["identities.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("identity_id", "agent_id"),
|
||||
)
|
||||
|
||||
# Migrate existing relationships
|
||||
# First, get existing relationships where identity_id is not null
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO identities_agents (identity_id, agent_id)
|
||||
SELECT DISTINCT identity_id, id as agent_id
|
||||
FROM agents
|
||||
WHERE identity_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Remove old identity_id column from agents
|
||||
op.drop_column("agents", "identity_id")
|
||||
op.drop_column("agents", "identifier_key")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Add back the old columns to agents
|
||||
op.add_column("agents", sa.Column("identity_id", sa.String(), nullable=True))
|
||||
op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True))
|
||||
|
||||
# Migrate relationships back
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE agents a
|
||||
SET identity_id = ia.identity_id
|
||||
FROM identities_agents ia
|
||||
WHERE a.id = ia.agent_id
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the many-to-many table
|
||||
op.drop_table("identities_agents")
|
||||
|
||||
# Drop properties column
|
||||
op.drop_column("identities", "properties")
|
||||
|
||||
# Restore old unique constraint
|
||||
op.drop_constraint("unique_identifier_without_project", "identities", type_="unique")
|
||||
op.create_unique_constraint("unique_identifier_pid_org_id", "identities", ["identifier_key", "project_id", "organization_id"])
|
||||
# ### end Alembic commands ###
|
||||
@@ -4,6 +4,7 @@ from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.identities_agents import IdentitiesAgents
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.job_messages import JobMessage
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, Boolean, ForeignKey, Index, String
|
||||
from sqlalchemy import JSON, Boolean, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.block import Block
|
||||
@@ -61,14 +61,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.")
|
||||
base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.")
|
||||
|
||||
# Identity
|
||||
identity_id: Mapped[Optional[str]] = mapped_column(
|
||||
String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to."
|
||||
)
|
||||
identifier_key: Mapped[Optional[str]] = mapped_column(
|
||||
String, nullable=True, doc="The identifier key of the identity the agent belongs to."
|
||||
)
|
||||
|
||||
# Tool rules
|
||||
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")
|
||||
|
||||
@@ -79,7 +71,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
|
||||
identity: Mapped["Identity"] = relationship("Identity", back_populates="agents")
|
||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable",
|
||||
back_populates="agent",
|
||||
@@ -130,7 +121,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
|
||||
doc="All passages derived created by this agent.",
|
||||
)
|
||||
identity: Mapped[Optional["Identity"]] = relationship("Identity", back_populates="agents")
|
||||
identities: Mapped[List["Identity"]] = relationship(
|
||||
"Identity",
|
||||
secondary="identities_agents",
|
||||
lazy="selectin",
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> PydanticAgentState:
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
@@ -160,6 +157,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"project_id": self.project_id,
|
||||
"template_id": self.template_id,
|
||||
"base_template_id": self.base_template_id,
|
||||
"identity_ids": [identity.id for identity in self.identities],
|
||||
"message_buffer_autoclear": self.message_buffer_autoclear,
|
||||
}
|
||||
|
||||
|
||||
13
letta/orm/identities_agents.py
Normal file
13
letta/orm/identities_agents.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class IdentitiesAgents(Base):
|
||||
"""Identities may have one or many agents associated with them."""
|
||||
|
||||
__tablename__ = "identities_agents"
|
||||
|
||||
identity_id: Mapped[str] = mapped_column(String, ForeignKey("identities.id", ondelete="CASCADE"), primary_key=True)
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
|
||||
@@ -2,11 +2,13 @@ import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.identity import Identity as PydanticIdentity
|
||||
from letta.schemas.identity import IdentityProperty
|
||||
|
||||
|
||||
class Identity(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -14,17 +16,35 @@ class Identity(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
__tablename__ = "identities"
|
||||
__pydantic_model__ = PydanticIdentity
|
||||
__table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"identifier_key",
|
||||
"project_id",
|
||||
"organization_id",
|
||||
name="unique_identifier_without_project",
|
||||
postgresql_nulls_not_distinct=True,
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}")
|
||||
identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.")
|
||||
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.")
|
||||
identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.")
|
||||
project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.")
|
||||
properties: Mapped[List["IdentityProperty"]] = mapped_column(
|
||||
JSONB, nullable=False, default=list, doc="List of properties associated with the identity"
|
||||
)
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||
)
|
||||
|
||||
@property
|
||||
def agent_ids(self) -> List[str]:
|
||||
"""Get just the agent IDs without loading the full agent objects"""
|
||||
return [agent.id for agent in self.agents]
|
||||
|
||||
def to_pydantic(self) -> PydanticIdentity:
|
||||
state = {
|
||||
@@ -33,7 +53,8 @@ class Identity(SqlalchemyBase, OrganizationMixin):
|
||||
"name": self.name,
|
||||
"identity_type": self.identity_type,
|
||||
"project_id": self.project_id,
|
||||
"agents": [agent.to_pydantic() for agent in self.agents],
|
||||
"agent_ids": self.agent_ids,
|
||||
"organization_id": self.organization_id,
|
||||
"properties": self.properties,
|
||||
}
|
||||
|
||||
return self.__pydantic_model__(**state)
|
||||
return PydanticIdentity(**state)
|
||||
|
||||
@@ -68,6 +68,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
join_model: Optional[Base] = None,
|
||||
join_conditions: Optional[Union[Tuple, List]] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> List["SqlalchemyBase"]:
|
||||
"""
|
||||
@@ -143,6 +144,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
# Group by primary key and all necessary columns to avoid JSON comparison
|
||||
query = query.group_by(cls.id)
|
||||
|
||||
if identifier_keys and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
|
||||
# Apply filtering logic from kwargs
|
||||
for key, value in kwargs.items():
|
||||
if "." in key:
|
||||
|
||||
@@ -83,9 +83,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
|
||||
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
|
||||
|
||||
# Identity
|
||||
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
|
||||
identity_ids: List[str] = Field([], description="The ids of the identities associated with this agent.")
|
||||
|
||||
# An advanced configuration that makes it so this agent does not remember any previous messages
|
||||
message_buffer_autoclear: bool = Field(
|
||||
@@ -161,7 +159,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
|
||||
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
|
||||
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
|
||||
identity_ids: Optional[List[str]] = Field(None, description="The ids of the identities associated with this agent.")
|
||||
message_buffer_autoclear: bool = Field(
|
||||
False,
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
@@ -236,7 +234,7 @@ class UpdateAgent(BaseModel):
|
||||
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
|
||||
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
|
||||
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
|
||||
identity_ids: Optional[List[str]] = Field(None, description="The ids of the identities associated with this agent.")
|
||||
message_buffer_autoclear: Optional[bool] = Field(
|
||||
None,
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -17,17 +16,38 @@ class IdentityType(str, Enum):
|
||||
other = "other"
|
||||
|
||||
|
||||
class IdentityPropertyType(str, Enum):
|
||||
"""
|
||||
Enum to represent the type of the identity property.
|
||||
"""
|
||||
|
||||
string = "string"
|
||||
number = "number"
|
||||
boolean = "boolean"
|
||||
json = "json"
|
||||
|
||||
|
||||
class IdentityBase(LettaBase):
|
||||
__id_prefix__ = "identity"
|
||||
|
||||
|
||||
class IdentityProperty(LettaBase):
|
||||
"""A property of an identity"""
|
||||
|
||||
key: str = Field(..., description="The key of the property")
|
||||
value: Union[str, int, float, bool, dict] = Field(..., description="The value of the property")
|
||||
type: IdentityPropertyType = Field(..., description="The type of the property")
|
||||
|
||||
|
||||
class Identity(IdentityBase):
|
||||
id: str = IdentityBase.generate_id_field()
|
||||
identifier_key: str = Field(..., description="External, user-generated identifier key of the identity.")
|
||||
name: str = Field(..., description="The name of the identity.")
|
||||
identity_type: IdentityType = Field(..., description="The type of the identity.")
|
||||
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
|
||||
agents: List[AgentState] = Field(..., description="The agents associated with the identity.")
|
||||
agent_ids: List[str] = Field(..., description="The IDs of the agents associated with the identity.")
|
||||
organization_id: Optional[str] = Field(None, description="The organization id of the user")
|
||||
properties: List[IdentityProperty] = Field(default_factory=list, description="List of properties associated with the identity")
|
||||
|
||||
|
||||
class IdentityCreate(LettaBase):
|
||||
@@ -36,9 +56,12 @@ class IdentityCreate(LettaBase):
|
||||
identity_type: IdentityType = Field(..., description="The type of the identity.")
|
||||
project_id: Optional[str] = Field(None, description="The project id of the identity, if applicable.")
|
||||
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
|
||||
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
|
||||
|
||||
|
||||
class IdentityUpdate(LettaBase):
|
||||
identifier_key: Optional[str] = Field(None, description="External, user-generated identifier key of the identity.")
|
||||
name: Optional[str] = Field(None, description="The name of the identity.")
|
||||
identity_type: Optional[IdentityType] = Field(None, description="The type of the identity.")
|
||||
agent_ids: Optional[List[str]] = Field(None, description="The agent ids that are associated with the identity.")
|
||||
properties: Optional[List[IdentityProperty]] = Field(None, description="List of properties associated with the identity.")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
@@ -51,7 +52,7 @@ def list_agents(
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project id"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template id"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
|
||||
identifier_key: Optional[str] = Query(None, description="Search agents by identifier key"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
@@ -67,7 +68,6 @@ def list_agents(
|
||||
"project_id": project_id,
|
||||
"template_id": template_id,
|
||||
"base_template_id": base_template_id,
|
||||
"identifier_key": identifier_key,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
@@ -81,6 +81,7 @@ def list_agents(
|
||||
query_text=query_text,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
identifier_keys=identifier_keys,
|
||||
**kwargs,
|
||||
)
|
||||
return agents
|
||||
@@ -119,8 +120,12 @@ def create_agent(
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.create_agent(agent, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{agent_id}", response_model=AgentState, operation_id="modify_agent")
|
||||
|
||||
@@ -16,6 +16,7 @@ router = APIRouter(prefix="/identities", tags=["identities"])
|
||||
def list_identities(
|
||||
name: Optional[str] = Query(None),
|
||||
project_id: Optional[str] = Query(None),
|
||||
identifier_key: Optional[str] = Query(None),
|
||||
identity_type: Optional[IdentityType] = Query(None),
|
||||
before: Optional[str] = Query(None),
|
||||
after: Optional[str] = Query(None),
|
||||
@@ -30,7 +31,14 @@ def list_identities(
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
identities = server.identity_manager.list_identities(
|
||||
name=name, project_id=project_id, identity_type=identity_type, before=before, after=after, limit=limit, actor=actor
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
identifier_key=identifier_key,
|
||||
identity_type=identity_type,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -39,13 +47,13 @@ def list_identities(
|
||||
return identities
|
||||
|
||||
|
||||
@router.get("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="get_identity_from_identifier_key")
|
||||
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
|
||||
def retrieve_identity(
|
||||
identifier_key: str,
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
try:
|
||||
return server.identity_manager.get_identity_from_identifier_key(identifier_key=identifier_key)
|
||||
return server.identity_manager.get_identity(identity_id=identity_id)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -82,25 +90,25 @@ def upsert_identity(
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
@router.patch("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="update_identity")
|
||||
@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity")
|
||||
def modify_identity(
|
||||
identifier_key: str,
|
||||
identity_id: str,
|
||||
identity: IdentityUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor)
|
||||
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
@router.delete("/{identifier_key}", tags=["identities"], operation_id="delete_identity")
|
||||
@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity")
|
||||
def delete_identity(
|
||||
identifier_key: str,
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
@@ -108,4 +116,4 @@ def delete_identity(
|
||||
Delete an identity by its identifier key
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
server.identity_manager.delete_identity_by_key(identifier_key=identifier_key, actor=actor)
|
||||
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import AgentPassage, AgentsTags
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Identity as IdentityModel
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import SourcePassage, SourcesAgents
|
||||
from letta.orm import Tool as ToolModel
|
||||
@@ -34,7 +35,6 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_identity,
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
check_supports_structured_output,
|
||||
@@ -138,6 +138,7 @@ class AgentManager:
|
||||
tool_ids=tool_ids,
|
||||
source_ids=agent_create.source_ids or [],
|
||||
tags=agent_create.tags or [],
|
||||
identity_ids=agent_create.identity_ids or [],
|
||||
description=agent_create.description,
|
||||
metadata=agent_create.metadata,
|
||||
tool_rules=tool_rules,
|
||||
@@ -145,7 +146,6 @@ class AgentManager:
|
||||
project_id=agent_create.project_id,
|
||||
template_id=agent_create.template_id,
|
||||
base_template_id=agent_create.base_template_id,
|
||||
identifier_key=agent_create.identifier_key,
|
||||
message_buffer_autoclear=agent_create.message_buffer_autoclear,
|
||||
)
|
||||
|
||||
@@ -203,13 +203,13 @@ class AgentManager:
|
||||
tool_ids: List[str],
|
||||
source_ids: List[str],
|
||||
tags: List[str],
|
||||
identity_ids: List[str],
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
tool_rules: Optional[List[PydanticToolRule]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
base_template_id: Optional[str] = None,
|
||||
identifier_key: Optional[str] = None,
|
||||
message_buffer_autoclear: bool = False,
|
||||
) -> PydanticAgentState:
|
||||
"""Create a new agent."""
|
||||
@@ -237,9 +237,7 @@ class AgentManager:
|
||||
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
|
||||
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
|
||||
_process_tags(new_agent, tags, replace=True)
|
||||
if identifier_key is not None:
|
||||
identity = self.identity_manager.get_identity_from_identifier_key(identifier_key)
|
||||
_process_identity(new_agent, identifier_key, identity)
|
||||
_process_relationship(session, new_agent, "identities", IdentityModel, identity_ids, replace=True)
|
||||
|
||||
new_agent.create(session, actor=actor)
|
||||
|
||||
@@ -313,9 +311,8 @@ class AgentManager:
|
||||
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
|
||||
if agent_update.tags is not None:
|
||||
_process_tags(agent, agent_update.tags, replace=True)
|
||||
if agent_update.identifier_key is not None:
|
||||
identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key)
|
||||
_process_identity(agent, agent_update.identifier_key, identity)
|
||||
if agent_update.identity_ids is not None:
|
||||
_process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True)
|
||||
|
||||
# Commit and refresh the agent
|
||||
agent.update(session, actor=actor)
|
||||
@@ -333,6 +330,7 @@ class AgentManager:
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
query_text: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
@@ -348,6 +346,7 @@ class AgentManager:
|
||||
match_all_tags=match_all_tags,
|
||||
organization_id=actor.organization_id if actor else None,
|
||||
query_text=query_text,
|
||||
identifier_keys=identifier_keys,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.identity import Identity
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
@@ -85,20 +84,6 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True):
|
||||
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
|
||||
|
||||
|
||||
def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity):
|
||||
"""
|
||||
Handles identity for an agent.
|
||||
|
||||
Args:
|
||||
agent: The AgentModel instance.
|
||||
identifier_key: The identifier key of the identity to set or update.
|
||||
identity: The Identity object to set or update.
|
||||
"""
|
||||
agent.identifier_key = identifier_key
|
||||
agent.identity = identity
|
||||
agent.identity_id = identity.id
|
||||
|
||||
|
||||
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
if system is None:
|
||||
# TODO: don't hardcode
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -23,6 +24,7 @@ class IdentityManager:
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
identifier_key: Optional[str] = None,
|
||||
identity_type: Optional[IdentityType] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
@@ -33,6 +35,8 @@ class IdentityManager:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
if identifier_key:
|
||||
filters["identifier_key"] = identifier_key
|
||||
if identity_type:
|
||||
filters["identity_type"] = identity_type
|
||||
identities = IdentityModel.list(
|
||||
@@ -46,9 +50,9 @@ class IdentityManager:
|
||||
return [identity.to_pydantic() for identity in identities]
|
||||
|
||||
@enforce_types
|
||||
def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
|
||||
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
with self.session_maker() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
return identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -68,34 +72,46 @@ class IdentityManager:
|
||||
identifier_key=identity.identifier_key,
|
||||
project_id=identity.project_id,
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if existing_identity is None:
|
||||
return self.create_identity(identity=identity, actor=actor)
|
||||
else:
|
||||
if existing_identity.identifier_key != identity.identifier_key:
|
||||
raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
|
||||
if existing_identity.project_id != identity.project_id:
|
||||
raise HTTPException(status_code=400, detail="Project id is an immutable field")
|
||||
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
|
||||
return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True)
|
||||
return self._update_identity(
|
||||
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
def update_identity_by_key(
|
||||
self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
|
||||
) -> PydanticIdentity:
|
||||
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
if existing_identity.organization_id != actor.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
existing_identity.name = identity.name if identity.name is not None else existing_identity.name
|
||||
existing_identity.identity_type = (
|
||||
identity.identity_type if identity.identity_type is not None else existing_identity.identity_type
|
||||
return self._update_identity(
|
||||
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
|
||||
)
|
||||
|
||||
def _update_identity(
|
||||
self,
|
||||
session: Session,
|
||||
existing_identity: IdentityModel,
|
||||
identity: IdentityUpdate,
|
||||
actor: PydanticUser,
|
||||
replace: bool = False,
|
||||
) -> PydanticIdentity:
|
||||
if identity.identifier_key is not None:
|
||||
existing_identity.identifier_key = identity.identifier_key
|
||||
if identity.name is not None:
|
||||
existing_identity.name = identity.name
|
||||
if identity.identity_type is not None:
|
||||
existing_identity.identity_type = identity.identity_type
|
||||
|
||||
self._process_agent_relationship(
|
||||
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
|
||||
)
|
||||
@@ -103,9 +119,9 @@ class IdentityManager:
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_identity_by_key(self, identifier_key: str, actor: PydanticUser) -> None:
|
||||
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
|
||||
with self.session_maker() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id)
|
||||
if identity is None:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
if identity.organization_id != actor.organization_id:
|
||||
|
||||
Reference in New Issue
Block a user