feat: add identifier key to agents (#1043)

same as https://github.com/letta-ai/letta-cloud/pull/1004
This commit is contained in:
cthomas
2025-02-18 16:06:09 -08:00
committed by GitHub
parent 98f0062416
commit 3a2a337256
7 changed files with 83 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
"""add identifier key to agents
Revision ID: a3047a624130
Revises: a113caac453e
Create Date: 2025-02-14 12:24:16.123456
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "a3047a624130"
down_revision: Union[str, None] = "a113caac453e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("agents", "identifier_key")

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
from letta.orm.identity import Identity
from letta.orm.message import Message
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
@@ -64,6 +65,9 @@ class Agent(SqlalchemyBase, OrganizationMixin):
identity_id: Mapped[Optional[str]] = mapped_column(
String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to."
)
identifier_key: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="The identifier key of the identity the agent belongs to."
)
# Tool rules
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")
@@ -75,6 +79,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
identity: Mapped["Identity"] = relationship("Identity", back_populates="agents")
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
"AgentEnvironmentVariable",
back_populates="agent",

View File

@@ -84,6 +84,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
# Identity
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
# An advanced configuration that makes it so this agent does not remember any previous messages
message_buffer_autoclear: bool = Field(
False,
@@ -155,6 +158,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
message_buffer_autoclear: bool = Field(
False,
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
@@ -229,6 +233,7 @@ class UpdateAgent(BaseModel):
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
message_buffer_autoclear: Optional[bool] = Field(
None,
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",

View File

@@ -50,6 +50,7 @@ def list_agents(
project_id: Optional[str] = Query(None, description="Search agents by project id"),
template_id: Optional[str] = Query(None, description="Search agents by template id"),
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
identifier_key: Optional[str] = Query(None, description="Search agents by identifier key"),
):
"""
List all agents associated with a given user.
@@ -65,6 +66,7 @@ def list_agents(
"project_id": project_id,
"template_id": template_id,
"base_template_id": base_template_id,
"identifier_key": identifier_key,
}.items()
if value is not None
}

View File

@@ -32,6 +32,7 @@ from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import SerializedAgentSchema
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_identity,
_process_relationship,
_process_tags,
check_supports_structured_output,
@@ -40,6 +41,7 @@ from letta.services.helpers.agent_manager_helper import (
initialize_message_sequence,
package_initial_message_sequence,
)
from letta.services.identity_manager import IdentityManager
from letta.services.message_manager import MessageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
@@ -61,6 +63,7 @@ class AgentManager:
self.tool_manager = ToolManager()
self.source_manager = SourceManager()
self.message_manager = MessageManager()
self.identity_manager = IdentityManager()
# ======================================================================================================================
# Basic CRUD operations
@@ -125,6 +128,7 @@ class AgentManager:
project_id=agent_create.project_id,
template_id=agent_create.template_id,
base_template_id=agent_create.base_template_id,
identifier_key=agent_create.identifier_key,
message_buffer_autoclear=agent_create.message_buffer_autoclear,
)
@@ -188,6 +192,7 @@ class AgentManager:
project_id: Optional[str] = None,
template_id: Optional[str] = None,
base_template_id: Optional[str] = None,
identifier_key: Optional[str] = None,
message_buffer_autoclear: bool = False,
) -> PydanticAgentState:
"""Create a new agent."""
@@ -215,6 +220,10 @@ class AgentManager:
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
_process_tags(new_agent, tags, replace=True)
if identifier_key is not None:
identity = self.identity_manager.get_identity_from_identifier_key(identifier_key)
_process_identity(new_agent, identifier_key, identity)
new_agent.create(session, actor=actor)
# Convert to PydanticAgentState and return
@@ -287,6 +296,9 @@ class AgentManager:
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
if agent_update.tags is not None:
_process_tags(agent, agent_update.tags, replace=True)
if agent_update.identifier_key is not None:
identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key)
_process_identity(agent, agent_update.identifier_key, identity)
# Commit and refresh the agent
agent.update(session, actor=actor)

View File

@@ -11,6 +11,7 @@ from letta.orm.errors import NoResultFound
from letta.prompts import gpt_system
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.identity import Identity
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate, TextContent
from letta.schemas.tool_rule import ToolRule
@@ -84,6 +85,20 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True):
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity):
"""
Handles identity for an agent.
Args:
agent: The AgentModel instance.
identifier_key: The identifier key of the identity to set or update.
identity: The Identity object to set or update.
"""
agent.identifier_key = identifier_key
agent.identity = identity
agent.identity_id = identity.id
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
if system is None:
# TODO: don't hardcode

View File

@@ -0,0 +1,17 @@
from letta.orm.identity import Identity as IdentityModel
from letta.schemas.identity import Identity as PydanticIdentity
from letta.utils import enforce_types
class IdentityManager:
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
with self.session_maker() as session:
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
return identity.to_pydantic()