feat: make identities many to many (#1085)
This commit is contained in:
@@ -11,6 +11,7 @@ from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import AgentPassage, AgentsTags
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Identity as IdentityModel
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import SourcePassage, SourcesAgents
|
||||
from letta.orm import Tool as ToolModel
|
||||
@@ -34,7 +35,6 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_identity,
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
check_supports_structured_output,
|
||||
@@ -138,6 +138,7 @@ class AgentManager:
|
||||
tool_ids=tool_ids,
|
||||
source_ids=agent_create.source_ids or [],
|
||||
tags=agent_create.tags or [],
|
||||
identity_ids=agent_create.identity_ids or [],
|
||||
description=agent_create.description,
|
||||
metadata=agent_create.metadata,
|
||||
tool_rules=tool_rules,
|
||||
@@ -145,7 +146,6 @@ class AgentManager:
|
||||
project_id=agent_create.project_id,
|
||||
template_id=agent_create.template_id,
|
||||
base_template_id=agent_create.base_template_id,
|
||||
identifier_key=agent_create.identifier_key,
|
||||
message_buffer_autoclear=agent_create.message_buffer_autoclear,
|
||||
)
|
||||
|
||||
@@ -203,13 +203,13 @@ class AgentManager:
|
||||
tool_ids: List[str],
|
||||
source_ids: List[str],
|
||||
tags: List[str],
|
||||
identity_ids: List[str],
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
tool_rules: Optional[List[PydanticToolRule]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
base_template_id: Optional[str] = None,
|
||||
identifier_key: Optional[str] = None,
|
||||
message_buffer_autoclear: bool = False,
|
||||
) -> PydanticAgentState:
|
||||
"""Create a new agent."""
|
||||
@@ -237,9 +237,7 @@ class AgentManager:
|
||||
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
|
||||
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
|
||||
_process_tags(new_agent, tags, replace=True)
|
||||
if identifier_key is not None:
|
||||
identity = self.identity_manager.get_identity_from_identifier_key(identifier_key)
|
||||
_process_identity(new_agent, identifier_key, identity)
|
||||
_process_relationship(session, new_agent, "identities", IdentityModel, identity_ids, replace=True)
|
||||
|
||||
new_agent.create(session, actor=actor)
|
||||
|
||||
@@ -313,9 +311,8 @@ class AgentManager:
|
||||
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
|
||||
if agent_update.tags is not None:
|
||||
_process_tags(agent, agent_update.tags, replace=True)
|
||||
if agent_update.identifier_key is not None:
|
||||
identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key)
|
||||
_process_identity(agent, agent_update.identifier_key, identity)
|
||||
if agent_update.identity_ids is not None:
|
||||
_process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True)
|
||||
|
||||
# Commit and refresh the agent
|
||||
agent.update(session, actor=actor)
|
||||
@@ -333,6 +330,7 @@ class AgentManager:
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
query_text: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
@@ -348,6 +346,7 @@ class AgentManager:
|
||||
match_all_tags=match_all_tags,
|
||||
organization_id=actor.organization_id if actor else None,
|
||||
query_text=query_text,
|
||||
identifier_keys=identifier_keys,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.identity import Identity
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
@@ -85,20 +84,6 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True):
|
||||
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
|
||||
|
||||
|
||||
def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity):
|
||||
"""
|
||||
Handles identity for an agent.
|
||||
|
||||
Args:
|
||||
agent: The AgentModel instance.
|
||||
identifier_key: The identifier key of the identity to set or update.
|
||||
identity: The Identity object to set or update.
|
||||
"""
|
||||
agent.identifier_key = identifier_key
|
||||
agent.identity = identity
|
||||
agent.identity_id = identity.id
|
||||
|
||||
|
||||
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
if system is None:
|
||||
# TODO: don't hardcode
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -23,6 +24,7 @@ class IdentityManager:
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
identifier_key: Optional[str] = None,
|
||||
identity_type: Optional[IdentityType] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
@@ -33,6 +35,8 @@ class IdentityManager:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
if identifier_key:
|
||||
filters["identifier_key"] = identifier_key
|
||||
if identity_type:
|
||||
filters["identity_type"] = identity_type
|
||||
identities = IdentityModel.list(
|
||||
@@ -46,9 +50,9 @@ class IdentityManager:
|
||||
return [identity.to_pydantic() for identity in identities]
|
||||
|
||||
@enforce_types
|
||||
def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
|
||||
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
with self.session_maker() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
return identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -68,44 +72,56 @@ class IdentityManager:
|
||||
identifier_key=identity.identifier_key,
|
||||
project_id=identity.project_id,
|
||||
organization_id=actor.organization_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if existing_identity is None:
|
||||
return self.create_identity(identity=identity, actor=actor)
|
||||
else:
|
||||
if existing_identity.identifier_key != identity.identifier_key:
|
||||
raise HTTPException(status_code=400, detail="Identifier key is an immutable field")
|
||||
if existing_identity.project_id != identity.project_id:
|
||||
raise HTTPException(status_code=400, detail="Project id is an immutable field")
|
||||
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
|
||||
return self.update_identity_by_key(identity.identifier_key, identity_update, actor, replace=True)
|
||||
return self._update_identity(
|
||||
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
def update_identity_by_key(
|
||||
self, identifier_key: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
|
||||
) -> PydanticIdentity:
|
||||
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
existing_identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
if existing_identity.organization_id != actor.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
existing_identity.name = identity.name if identity.name is not None else existing_identity.name
|
||||
existing_identity.identity_type = (
|
||||
identity.identity_type if identity.identity_type is not None else existing_identity.identity_type
|
||||
return self._update_identity(
|
||||
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
|
||||
)
|
||||
self._process_agent_relationship(
|
||||
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
|
||||
)
|
||||
existing_identity.update(session, actor=actor)
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
def _update_identity(
|
||||
self,
|
||||
session: Session,
|
||||
existing_identity: IdentityModel,
|
||||
identity: IdentityUpdate,
|
||||
actor: PydanticUser,
|
||||
replace: bool = False,
|
||||
) -> PydanticIdentity:
|
||||
if identity.identifier_key is not None:
|
||||
existing_identity.identifier_key = identity.identifier_key
|
||||
if identity.name is not None:
|
||||
existing_identity.name = identity.name
|
||||
if identity.identity_type is not None:
|
||||
existing_identity.identity_type = identity.identity_type
|
||||
|
||||
self._process_agent_relationship(
|
||||
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
|
||||
)
|
||||
existing_identity.update(session, actor=actor)
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_identity_by_key(self, identifier_key: str, actor: PydanticUser) -> None:
|
||||
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
|
||||
with self.session_maker() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id)
|
||||
if identity is None:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
if identity.organization_id != actor.organization_id:
|
||||
|
||||
Reference in New Issue
Block a user