feat: make identities many to many (#1085)

This commit is contained in:
cthomas
2025-02-20 16:33:24 -08:00
committed by GitHub
parent afbb5af30b
commit 31130a6d28
13 changed files with 243 additions and 83 deletions

View File

@@ -11,6 +11,7 @@ from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage, AgentsTags
from letta.orm import Block as BlockModel
from letta.orm import Identity as IdentityModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
@@ -34,7 +35,6 @@ from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import SerializedAgentSchema
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_identity,
_process_relationship,
_process_tags,
check_supports_structured_output,
@@ -138,6 +138,7 @@ class AgentManager:
tool_ids=tool_ids,
source_ids=agent_create.source_ids or [],
tags=agent_create.tags or [],
identity_ids=agent_create.identity_ids or [],
description=agent_create.description,
metadata=agent_create.metadata,
tool_rules=tool_rules,
@@ -145,7 +146,6 @@ class AgentManager:
project_id=agent_create.project_id,
template_id=agent_create.template_id,
base_template_id=agent_create.base_template_id,
identifier_key=agent_create.identifier_key,
message_buffer_autoclear=agent_create.message_buffer_autoclear,
)
@@ -203,13 +203,13 @@ class AgentManager:
tool_ids: List[str],
source_ids: List[str],
tags: List[str],
identity_ids: List[str],
description: Optional[str] = None,
metadata: Optional[Dict] = None,
tool_rules: Optional[List[PydanticToolRule]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
base_template_id: Optional[str] = None,
identifier_key: Optional[str] = None,
message_buffer_autoclear: bool = False,
) -> PydanticAgentState:
"""Create a new agent."""
@@ -237,9 +237,7 @@ class AgentManager:
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
_process_tags(new_agent, tags, replace=True)
if identifier_key is not None:
identity = self.identity_manager.get_identity_from_identifier_key(identifier_key)
_process_identity(new_agent, identifier_key, identity)
_process_relationship(session, new_agent, "identities", IdentityModel, identity_ids, replace=True)
new_agent.create(session, actor=actor)
@@ -313,9 +311,8 @@ class AgentManager:
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
if agent_update.tags is not None:
_process_tags(agent, agent_update.tags, replace=True)
if agent_update.identifier_key is not None:
identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key)
_process_identity(agent, agent_update.identifier_key, identity)
if agent_update.identity_ids is not None:
_process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True)
# Commit and refresh the agent
agent.update(session, actor=actor)
@@ -333,6 +330,7 @@ class AgentManager:
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
query_text: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
**kwargs,
) -> List[PydanticAgentState]:
"""
@@ -348,6 +346,7 @@ class AgentManager:
match_all_tags=match_all_tags,
organization_id=actor.organization_id if actor else None,
query_text=query_text,
identifier_keys=identifier_keys,
**kwargs,
)

View File

@@ -11,7 +11,6 @@ from letta.orm.errors import NoResultFound
from letta.prompts import gpt_system
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.identity import Identity
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate, TextContent
from letta.schemas.tool_rule import ToolRule
@@ -85,20 +84,6 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True):
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity):
"""
Handles identity for an agent.
Args:
agent: The AgentModel instance.
identifier_key: The identifier key of the identity to set or update.
identity: The Identity object to set or update.
"""
agent.identifier_key = identifier_key
agent.identity = identity
agent.identity_id = identity.id
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
if system is None:
# TODO: don't hardcode

View File

@@ -1,6 +1,7 @@
from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from letta.orm.agent import Agent as AgentModel
@@ -23,6 +24,7 @@ class IdentityManager:
self,
name: Optional[str] = None,
project_id: Optional[str] = None,
identifier_key: Optional[str] = None,
identity_type: Optional[IdentityType] = None,
before: Optional[str] = None,
after: Optional[str] = None,
@@ -33,6 +35,8 @@ class IdentityManager:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
if identifier_key:
filters["identifier_key"] = identifier_key
if identity_type:
filters["identity_type"] = identity_type
identities = IdentityModel.list(
@@ -46,9 +50,9 @@ class IdentityManager:
return [identity.to_pydantic() for identity in identities]
@enforce_types
def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()
@enforce_types
@@ -68,44 +72,56 @@ class IdentityManager:
identifier_key=identity.identifier_key,
project_id=identity.project_id,
organization_id=actor.organization_id,
actor=actor,
)
if existing_identity is None:
return self.create_identity(identity=identity, actor=actor)
else:
if existing_identity.identifier_key != identity.identifier_key:
raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
if existing_identity.project_id != identity.project_id:
raise HTTPException(status_code=400, detail="Project id is an immutable field")
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True)
return self._update_identity(
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
)
@enforce_types
def update_identity_by_key(
self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
) -> PydanticIdentity:
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
with self.session_maker() as session:
try:
existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail="Identity not found")
if existing_identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
existing_identity.name = identity.name if identity.name is not None else existing_identity.name
existing_identity.identity_type = (
identity.identity_type if identity.identity_type is not None else existing_identity.identity_type
return self._update_identity(
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
)
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
)
existing_identity.update(session, actor=actor)
return existing_identity.to_pydantic()
def _update_identity(
self,
session: Session,
existing_identity: IdentityModel,
identity: IdentityUpdate,
actor: PydanticUser,
replace: bool = False,
) -> PydanticIdentity:
if identity.identifier_key is not None:
existing_identity.identifier_key = identity.identifier_key
if identity.name is not None:
existing_identity.name = identity.name
if identity.identity_type is not None:
existing_identity.identity_type = identity.identity_type
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
)
existing_identity.update(session, actor=actor)
return existing_identity.to_pydantic()
@enforce_types
def delete_identity_by_key(self, identifier_key: str, actor: PydanticUser) -> None:
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
identity = IdentityModel.read(db_session=session, identifier=identity_id)
if identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
if identity.organization_id != actor.organization_id: