diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 09f6cd86..b8151e7a 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -519,7 +519,7 @@ def _prepare_anthropic_request( prefix_fill: bool = True, # if true, put COT inside the tool calls instead of inside the content put_inner_thoughts_in_kwargs: bool = False, - bedrock: bool = False + bedrock: bool = False, ) -> dict: """Prepare the request data for Anthropic API format.""" @@ -607,7 +607,7 @@ def _prepare_anthropic_request( # NOTE: cannot prefill with tools for opus: # Your API request included an `assistant` message in the final position, which would pre-fill the `assistant` response. When using tools with "claude-3-opus-20240229" if prefix_fill and not put_inner_thoughts_in_kwargs and "opus" not in data["model"]: - if not bedrock: # not support for bedrock + if not bedrock: # not support for bedrock data["messages"].append( # Start the thinking process for the assistant {"role": "assistant", "content": f"<{inner_thoughts_xml_tag}>"}, diff --git a/letta/orm/identity.py b/letta/orm/identity.py index 52d7b8ab..4a7cfefd 100644 --- a/letta/orm/identity.py +++ b/letta/orm/identity.py @@ -1,6 +1,7 @@ +import uuid from typing import List, Optional -from sqlalchemy import UniqueConstraint +from sqlalchemy import String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -15,6 +16,7 @@ class Identity(SqlalchemyBase, OrganizationMixin): __pydantic_model__ = PydanticIdentity __table_args__ = (UniqueConstraint("identifier_key", "project_id", "organization_id", name="unique_identifier_pid_org_id"),) + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"identity-{uuid.uuid4()}") identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.") name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.") identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.") @@ -22,4 +24,16 @@ class Identity(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="identities") - agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="identity") + agents: Mapped[List["Agent"]] = relationship("Agent", lazy="selectin", back_populates="identity") + + def to_pydantic(self) -> PydanticIdentity: + state = { + "id": self.id, + "identifier_key": self.identifier_key, + "name": self.name, + "identity_type": self.identity_type, + "project_id": self.project_id, + "agents": [agent.to_pydantic() for agent in self.agents], + } + + return self.__pydantic_model__(**state) diff --git a/letta/schemas/identity.py b/letta/schemas/identity.py index 5796685b..204826a9 100644 --- a/letta/schemas/identity.py +++ b/letta/schemas/identity.py @@ -22,7 +22,7 @@ class IdentityBase(LettaBase): class Identity(IdentityBase): - id: str = Field(..., description="The internal id of the identity.") + id: str = IdentityBase.generate_id_field() identifier_key: str = Field(..., description="External, user-generated identifier key of the identity.") name: str = Field(..., description="The name of the identity.") identity_type: IdentityType = Field(..., description="The type of the identity.") diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index ecb6c7cc..83766339 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from letta.orm.errors import NoResultFound from letta.schemas.identity import Identity, IdentityCreate, IdentityType, IdentityUpdate from letta.server.rest_api.utils import get_letta_server @@ -44,7 +45,7 @@ def retrieve_identity( server: "SyncServer" = Depends(get_letta_server), ): try: - return server.identity_manager.get_identity_by_identifier_key(identifier_key=identifier_key) + return server.identity_manager.get_identity_from_identifier_key(identifier_key=identifier_key) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -56,8 +57,13 @@ def create_identity( user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present project_slug: Optional[str] = Header(None, alias="project-slug"), # Only handled by next js middleware ): - actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.identity_manager.create_identity(identity=identity, actor=actor) + try: + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.identity_manager.create_identity(identity=identity, actor=actor) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") @router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity") @@ -67,8 +73,13 @@ def upsert_identity( user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present project_slug: Optional[str] = Header(None, alias="project-slug"), # Only handled by next js middleware ): - actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.identity_manager.upsert_identity(identity=identity, actor=actor) + try: + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.identity_manager.upsert_identity(identity=identity, actor=actor) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") @router.patch("/{identifier_key}", tags=["identities"], response_model=Identity, operation_id="update_identity") @@ -78,8 +89,13 @@ def modify_identity( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor) + try: + actor = server.user_manager.get_user_or_default(user_id=user_id) + return server.identity_manager.update_identity_by_key(identifier_key=identifier_key, identity=identity, actor=actor) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") @router.delete("/{identifier_key}", tags=["identities"], operation_id="delete_identity") diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 43228e47..3973fa1b 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional from fastapi import HTTPException from sqlalchemy.orm import Session @@ -54,57 +54,52 @@ class IdentityManager: @enforce_types def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: with self.session_maker() as session: - agents = self._get_agents_from_ids(session=session, agent_ids=identity.agent_ids, actor=actor) - - identity = IdentityModel.create( - db_session=session, - name=identity.name, - identifier_key=identity.identifier_key, - identity_type=identity.identity_type, - project_id=identity.project_id, - organization_id=actor.organization_id, - agents=agents, - ) - return identity.to_pydantic() + new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids"}, exclude_unset=True)) + new_identity.organization_id = actor.organization_id + self._process_agent_relationship(session=session, identity=new_identity, agent_ids=identity.agent_ids, allow_partial=False) + new_identity.create(session, actor=actor) + return new_identity.to_pydantic() @enforce_types def upsert_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: with self.session_maker() as session: - existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key) - if existing_identity is None: - identity = self.create_identity(identity=identity, actor=actor) - else: - if existing_identity.identifier_key != identity.identifier_key: - raise HTTPException(status_code=400, detail="Identifier key is an immutable field") - if existing_identity.project_id != identity.project_id: - raise HTTPException(status_code=400, detail="Project id is an immutable field") - if existing_identity.organization_id != identity.organization_id: - raise HTTPException(status_code=400, detail="Organization id is an immutable field") - identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids) - identity = self.update_identity_by_key(identity.identifier_key, identity_update, actor) - identity.commit(session) - return identity.to_pydantic() + existing_identity = IdentityModel.read( + db_session=session, + identifier_key=identity.identifier_key, + project_id=identity.project_id, + organization_id=actor.organization_id, + ) + + if existing_identity is None: + return self.create_identity(identity=identity, actor=actor) + else: + if existing_identity.identifier_key != identity.identifier_key: + raise HTTPException(status_code=400, detail="Identifier key is an immutable field") + if existing_identity.project_id != identity.project_id: + raise HTTPException(status_code=400, detail="Project id is an immutable field") + identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids) + return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True) @enforce_types - def update_identity_by_key(self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser) -> PydanticIdentity: + def update_identity_by_key( + self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False + ) -> PydanticIdentity: with self.session_maker() as session: try: existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key) except NoResultFound: raise HTTPException(status_code=404, detail="Identity not found") - if identity.organization_id != existing_identity.organization_id or identity.organization_id != actor.organization_id: + if existing_identity.organization_id != actor.organization_id: raise HTTPException(status_code=403, detail="Forbidden") - agents = None - if identity.agent_ids: - agents = self._get_agents_from_ids(session=session, agent_ids=identity.agent_ids, actor=actor) - existing_identity.name = identity.name if identity.name is not None else existing_identity.name existing_identity.identity_type = ( identity.identity_type if identity.identity_type is not None else existing_identity.identity_type ) - existing_identity.agents = agents if agents is not None else existing_identity.agents - existing_identity.commit(session) + self._process_agent_relationship( + session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace + ) + existing_identity.update(session, actor=actor) return existing_identity.to_pydantic() @enforce_types @@ -116,28 +111,30 @@ class IdentityManager: if identity.organization_id != actor.organization_id: raise HTTPException(status_code=403, detail="Forbidden") session.delete(identity) + session.commit() - def _get_agents_from_ids(self, session: Session, agent_ids: list[str], actor: PydanticUser) -> list[AgentModel]: - """Helper method to get agents from their IDs and verify permissions. + def _process_agent_relationship( + self, session: Session, identity: IdentityModel, agent_ids: List[str], allow_partial=False, replace=True + ): + current_relationship = getattr(identity, "agents", []) + if not agent_ids: + if replace: + setattr(identity, "agents", []) + return - Args: - session: The database session - agent_ids: List of agent IDs to fetch - actor: The user making the request + # Retrieve models for the provided IDs + found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all() - Returns: - List of agent models + # Validate all items are found if allow_partial is False + if not allow_partial and len(found_items) != len(agent_ids): + missing = set(agent_ids) - {item.id for item in found_items} + raise NoResultFound(f"Items not found in agents: {missing}") - Raises: - HTTPException: If agents not found or user doesn't have permission - """ - agents = AgentModel.list(db_session=session, ids=agent_ids) - if len(agents) != len(agent_ids): - found_ids = {agent.id for agent in agents} - missing_ids = [id for id in agent_ids if id not in found_ids] - raise HTTPException(status_code=404, detail=f"Agents not found: {', '.join(missing_ids)}") - - if any(agent.organization_id != actor.organization_id for agent in agents): - raise HTTPException(status_code=403, detail="Forbidden") - - return agents + if replace: + # Replace the relationship + setattr(identity, "agents", found_items) + else: + # Extend the relationship (only add new items) + current_ids = {item.id for item in current_relationship} + new_items = [item for item in found_items if item.id not in current_ids] + current_relationship.extend(new_items)