From 3a2a337256d7015ca87b7be098404d976ce7b49b Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 18 Feb 2025 16:06:09 -0800 Subject: [PATCH] feat: add identifier key to agents (#1043) same as https://github.com/letta-ai/letta-cloud/pull/1004 --- ...047a624130_add_identifier_key_to_agents.py | 27 +++++++++++++++++++ letta/orm/agent.py | 5 ++++ letta/schemas/agent.py | 5 ++++ letta/server/rest_api/routers/v1/agents.py | 2 ++ letta/services/agent_manager.py | 12 +++++++++ .../services/helpers/agent_manager_helper.py | 15 +++++++++++ letta/services/identity_manager.py | 17 ++++++++++++ 7 files changed, 83 insertions(+) create mode 100644 alembic/versions/a3047a624130_add_identifier_key_to_agents.py create mode 100644 letta/services/identity_manager.py diff --git a/alembic/versions/a3047a624130_add_identifier_key_to_agents.py b/alembic/versions/a3047a624130_add_identifier_key_to_agents.py new file mode 100644 index 00000000..abeaeef5 --- /dev/null +++ b/alembic/versions/a3047a624130_add_identifier_key_to_agents.py @@ -0,0 +1,27 @@ +"""add identifier key to agents + +Revision ID: a3047a624130 +Revises: a113caac453e +Create Date: 2025-02-14 12:24:16.123456 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a3047a624130" +down_revision: Union[str, None] = "a113caac453e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("agents", "identifier_key") diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 3555dd7f..918e5fa9 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn +from letta.orm.identity import Identity from letta.orm.message import Message from letta.orm.mixins import OrganizationMixin from letta.orm.organization import Organization @@ -64,6 +65,9 @@ class Agent(SqlalchemyBase, OrganizationMixin): identity_id: Mapped[Optional[str]] = mapped_column( String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to." ) + identifier_key: Mapped[Optional[str]] = mapped_column( + String, nullable=True, doc="The identifier key of the identity the agent belongs to." + ) # Tool rules tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.") @@ -75,6 +79,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") + identity: Mapped["Identity"] = relationship("Identity", back_populates="agents") tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( "AgentEnvironmentVariable", back_populates="agent", diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 30c9ed67..2be37939 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -84,6 +84,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True): template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + # Identity + identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.") + # An advanced configuration that makes it so this agent does not remember any previous messages message_buffer_autoclear: bool = Field( False, @@ -155,6 +158,7 @@ class CreateAgent(BaseModel, validate_assignment=True): # project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.") message_buffer_autoclear: bool = Field( False, description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", @@ -229,6 +233,7 @@ class UpdateAgent(BaseModel): project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.") message_buffer_autoclear: Optional[bool] = Field( None, description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index cd57e02e..50a034f2 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -50,6 +50,7 @@ def list_agents( project_id: Optional[str] = Query(None, description="Search agents by project id"), template_id: Optional[str] = Query(None, description="Search agents by template id"), base_template_id: Optional[str] = Query(None, description="Search agents by base template id"), + identifier_key: Optional[str] = Query(None, description="Search agents by identifier key"), ): """ List all agents associated with a given user. @@ -65,6 +66,7 @@ def list_agents( "project_id": project_id, "template_id": template_id, "base_template_id": base_template_id, + "identifier_key": identifier_key, }.items() if value is not None } diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index e2360240..23922326 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -32,6 +32,7 @@ from letta.schemas.user import User as PydanticUser from letta.serialize_schemas import SerializedAgentSchema from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( + _process_identity, _process_relationship, _process_tags, check_supports_structured_output, @@ -40,6 +41,7 @@ from letta.services.helpers.agent_manager_helper import ( initialize_message_sequence, package_initial_message_sequence, ) +from letta.services.identity_manager import IdentityManager from letta.services.message_manager import MessageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager @@ -61,6 +63,7 @@ class AgentManager: self.tool_manager = ToolManager() self.source_manager = SourceManager() self.message_manager = MessageManager() + self.identity_manager = IdentityManager() # ====================================================================================================================== # Basic CRUD operations @@ -125,6 +128,7 @@ class AgentManager: project_id=agent_create.project_id, template_id=agent_create.template_id, base_template_id=agent_create.base_template_id, + identifier_key=agent_create.identifier_key, message_buffer_autoclear=agent_create.message_buffer_autoclear, ) @@ -188,6 +192,7 @@ class AgentManager: project_id: Optional[str] = None, template_id: Optional[str] = None, base_template_id: Optional[str] = None, + identifier_key: Optional[str] = None, message_buffer_autoclear: bool = False, ) -> PydanticAgentState: """Create a new agent.""" @@ -215,6 +220,10 @@ class AgentManager: _process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True) _process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True) _process_tags(new_agent, tags, replace=True) + if identifier_key is not None: + identity = self.identity_manager.get_identity_from_identifier_key(identifier_key) + _process_identity(new_agent, identifier_key, identity) + new_agent.create(session, actor=actor) # Convert to PydanticAgentState and return @@ -287,6 +296,9 @@ class AgentManager: _process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True) if agent_update.tags is not None: _process_tags(agent, agent_update.tags, replace=True) + if agent_update.identifier_key is not None: + identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key) + _process_identity(agent, agent_update.identifier_key, identity) # Commit and refresh the agent agent.update(session, actor=actor) diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 8d99449c..098ffc52 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -11,6 +11,7 @@ from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.schemas.agent import AgentState, AgentType from letta.schemas.enums import MessageRole +from letta.schemas.identity import Identity from letta.schemas.memory import Memory from letta.schemas.message import Message, MessageCreate, TextContent from letta.schemas.tool_rule import ToolRule @@ -84,6 +85,20 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True): agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags]) +def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity): + """ + Handles identity for an agent. + + Args: + agent: The AgentModel instance. + identifier_key: The identifier key of the identity to set or update. + identity: The Identity object to set or update. + """ + agent.identifier_key = identifier_key + agent.identity = identity + agent.identity_id = identity.id + + def derive_system_message(agent_type: AgentType, system: Optional[str] = None): if system is None: # TODO: don't hardcode diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py new file mode 100644 index 00000000..9d996db7 --- /dev/null +++ b/letta/services/identity_manager.py @@ -0,0 +1,17 @@ +from letta.orm.identity import Identity as IdentityModel +from letta.schemas.identity import Identity as PydanticIdentity +from letta.utils import enforce_types + + +class IdentityManager: + + def __init__(self): + from letta.server.db import db_context + + self.session_maker = db_context + + @enforce_types + def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity: + with self.session_maker() as session: + identity = IdentityModel.read(db_session=session, identifier_key=identifier_key) + return identity.to_pydantic()