feat: add create identity (#1064)

This commit is contained in:
Shubham Naik
2025-02-19 22:16:21 -08:00
committed by GitHub
parent 41583d7d99
commit dbb28af496
5 changed files with 95 additions and 68 deletions

View File

@@ -519,7 +519,7 @@ def _prepare_anthropic_request(
prefix_fill: bool = True,
# if true, put COT inside the tool calls instead of inside the content
put_inner_thoughts_in_kwargs: bool = False,
bedrock: bool = False
bedrock: bool = False,
) -> dict:
"""Prepare the request data for Anthropic API format."""
@@ -607,7 +607,7 @@ def _prepare_anthropic_request(
# NOTE: cannot prefill with tools for opus:
# Your API request included an `assistant` message in the final position, which would pre-fill the `assistant` response. When using tools with "claude-3-opus-20240229"
if prefix_fill and not put_inner_thoughts_in_kwargs and "opus" not in data["model"]:
if not bedrock: # not support for bedrock
if not bedrock: # not support for bedrock
data["messages"].append(
# Start the thinking process for the assistant
{"role": "assistant", "content": f"<{inner_thoughts_xml_tag}>"},

View File

@@ -1,6 +1,7 @@
import uuid
from typing import List, Optional
from sqlalchemy import UniqueConstraint
from sqlalchemy import String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@@ -15,6 +16,7 @@ class Identity(SqlalchemyBase, OrganizationMixin):
__pydantic_model__ = PydanticIdentity
__table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),)
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}")
identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.")
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.")
identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.")
@@ -22,4 +24,16 @@ class Identity(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="identity")
agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity")
def to_pydantic(self) -> PydanticIdentity:
state = {
"id": self.id,
"identifier_key": self.identifier_key,
"name": self.name,
"identity_type": self.identity_type,
"project_id": self.project_id,
"agents": [agent.to_pydantic() for agent in self.agents],
}
return self.__pydantic_model__(**state)

View File

@@ -22,7 +22,7 @@ class IdentityBase(LettaBase):
class Identity(IdentityBase):
id: str = Field(..., description="The internal id of the identity.")
id: str = IdentityBase.generate_id_field()
identifier_key: str = Field(..., description="External, user-generated identifier key of the identity.")
name: str = Field(..., description="The name of the identity.")
identity_type: IdentityType = Field(..., description="The type of the identity.")

View File

@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.orm.errors import NoResultFound
from letta.schemas.identity import Identity, IdentityCreate, IdentityType, IdentityUpdate
from letta.server.rest_api.utils import get_letta_server
@@ -44,7 +45,7 @@ def retrieve_identity(
server: "SyncServer" = Depends(get_letta_server),
):
try:
return server.identity_manager.get_identity_by_identifier_key(identifier_key=identifier_key)
return server.identity_manager.get_identity_from_identifier_key(identifier_key=identifier_key)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -56,8 +57,13 @@ def create_identity(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
project_slug: Optional[str] = Header(None, alias="project-slug"), # Only handled by next js middleware
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.create_identity(identity=identity, actor=actor)
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.create_identity(identity=identity, actor=actor)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity")
@@ -67,8 +73,13 @@ def upsert_identity(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
project_slug: Optional[str] = Header(None, alias="project-slug"), # Only handled by next js middleware
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.patch("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="update_identity")
@@ -78,8 +89,13 @@ def modify_identity(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor)
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.delete("/{identifier_key}", tags=["identities"], operation_id="delete_identity")

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy.orm import Session
@@ -54,57 +54,52 @@ class IdentityManager:
@enforce_types
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
agents = self._get_agents_from_ids(session=session, agent_ids=identity.agent_ids, actor=actor)
identity = IdentityModel.create(
db_session=session,
name=identity.name,
identifier_key=identity.identifier_key,
identity_type=identity.identity_type,
project_id=identity.project_id,
organization_id=actor.organization_id,
agents=agents,
)
return identity.to_pydantic()
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
self._process_agent_relationship(session=session, identity=new_identity, agent_ids=identity.agent_ids, allow_partial=False)
new_identity.create(session, actor=actor)
return new_identity.to_pydantic()
@enforce_types
def upsert_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
if existing_identity is None:
identity = self.create_identity(identity=identity, actor=actor)
else:
if existing_identity.identifier_key != identity.identifier_key:
raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
if existing_identity.project_id != identity.project_id:
raise HTTPException(status_code=400, detail="Project id is an immutable field")
if existing_identity.organization_id != identity.organization_id:
raise HTTPException(status_code=400, detail="Organization id is an immutable field")
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
identity = self.update_identity_by_key(identity.identifier_key, identity_update, actor)
identity.commit(session)
return identity.to_pydantic()
existing_identity = IdentityModel.read(
db_session=session,
identifier_key=identity.identifier_key,
project_id=identity.project_id,
organization_id=actor.organization_id,
)
if existing_identity is None:
return self.create_identity(identity=identity, actor=actor)
else:
if existing_identity.identifier_key != identity.identifier_key:
raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
if existing_identity.project_id != identity.project_id:
raise HTTPException(status_code=400, detail="Project id is an immutable field")
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True)
@enforce_types
def update_identity_by_key(self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser) -> PydanticIdentity:
def update_identity_by_key(
self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
) -> PydanticIdentity:
with self.session_maker() as session:
try:
existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
except NoResultFound:
raise HTTPException(status_code=404, detail="Identity not found")
if identity.organization_id != existing_identity.organization_id or identity.organization_id != actor.organization_id:
if existing_identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
agents = None
if identity.agent_ids:
agents = self._get_agents_from_ids(session=session, agent_ids=identity.agent_ids, actor=actor)
existing_identity.name = identity.name if identity.name is not None else existing_identity.name
existing_identity.identity_type = (
identity.identity_type if identity.identity_type is not None else existing_identity.identity_type
)
existing_identity.agents = agents if agents is not None else existing_identity.agents
existing_identity.commit(session)
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
)
existing_identity.update(session, actor=actor)
return existing_identity.to_pydantic()
@enforce_types
@@ -116,28 +111,30 @@ class IdentityManager:
if identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
session.delete(identity)
session.commit()
def _get_agents_from_ids(self, session: Session, agent_ids: list[str], actor: PydanticUser) -> list[AgentModel]:
"""Helper method to get agents from their IDs and verify permissions.
def _process_agent_relationship(
self, session: Session, identity: IdentityModel, agent_ids: List[str], allow_partial=False, replace=True
):
current_relationship = getattr(identity, "agents", [])
if not agent_ids:
if replace:
setattr(identity, "agents", [])
return
Args:
session: The database session
agent_ids: List of agent IDs to fetch
actor: The user making the request
# Retrieve models for the provided IDs
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
Returns:
List of agent models
# Validate all items are found if allow_partial is False
if not allow_partial and len(found_items) != len(agent_ids):
missing = set(agent_ids) - {item.id for item in found_items}
raise NoResultFound(f"Items not found in agents: {missing}")
Raises:
HTTPException: If agents not found or user doesn't have permission
"""
agents = AgentModel.list(db_session=session, ids=agent_ids)
if len(agents) != len(agent_ids):
found_ids = {agent.id for agent in agents}
missing_ids = [id for id in agent_ids if id not in found_ids]
raise HTTPException(status_code=404, detail=f"Agents not found: {', '.join(missing_ids)}")
if any(agent.organization_id != actor.organization_id for agent in agents):
raise HTTPException(status_code=403, detail="Forbidden")
return agents
if replace:
# Replace the relationship
setattr(identity, "agents", found_items)
else:
# Extend the relationship (only add new items)
current_ids = {item.id for item in current_relationship}
new_items = [item for item in found_items if item.id not in current_ids]
current_relationship.extend(new_items)