feat: add identifier key to agents (#1043)
same as https://github.com/letta-ai/letta-cloud/pull/1004
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
"""add identifier key to agents
|
||||
|
||||
Revision ID: a3047a624130
|
||||
Revises: a113caac453e
|
||||
Create Date: 2025-02-14 12:24:16.123456
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a3047a624130"
|
||||
down_revision: Union[str, None] = "a113caac453e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("agents", sa.Column("identifier_key", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("agents", "identifier_key")
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.organization import Organization
|
||||
@@ -64,6 +65,9 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
identity_id: Mapped[Optional[str]] = mapped_column(
|
||||
String, ForeignKey("identities.id", ondelete="CASCADE"), nullable=True, doc="The id of the identity the agent belongs to."
|
||||
)
|
||||
identifier_key: Mapped[Optional[str]] = mapped_column(
|
||||
String, nullable=True, doc="The identifier key of the identity the agent belongs to."
|
||||
)
|
||||
|
||||
# Tool rules
|
||||
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")
|
||||
@@ -75,6 +79,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
|
||||
identity: Mapped["Identity"] = relationship("Identity", back_populates="agents")
|
||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable",
|
||||
back_populates="agent",
|
||||
|
||||
@@ -84,6 +84,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
|
||||
|
||||
# Identity
|
||||
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
|
||||
|
||||
# An advanced configuration that makes it so this agent does not remember any previous messages
|
||||
message_buffer_autoclear: bool = Field(
|
||||
False,
|
||||
@@ -155,6 +158,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
|
||||
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
|
||||
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
|
||||
message_buffer_autoclear: bool = Field(
|
||||
False,
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
@@ -229,6 +233,7 @@ class UpdateAgent(BaseModel):
|
||||
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
|
||||
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
|
||||
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")
|
||||
identifier_key: Optional[str] = Field(None, description="The identifier key belonging to the identity associated with this agent.")
|
||||
message_buffer_autoclear: Optional[bool] = Field(
|
||||
None,
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
|
||||
@@ -50,6 +50,7 @@ def list_agents(
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project id"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template id"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
|
||||
identifier_key: Optional[str] = Query(None, description="Search agents by identifier key"),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
@@ -65,6 +66,7 @@ def list_agents(
|
||||
"project_id": project_id,
|
||||
"template_id": template_id,
|
||||
"base_template_id": base_template_id,
|
||||
"identifier_key": identifier_key,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_identity,
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
check_supports_structured_output,
|
||||
@@ -40,6 +41,7 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
initialize_message_sequence,
|
||||
package_initial_message_sequence,
|
||||
)
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
@@ -61,6 +63,7 @@ class AgentManager:
|
||||
self.tool_manager = ToolManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.identity_manager = IdentityManager()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
@@ -125,6 +128,7 @@ class AgentManager:
|
||||
project_id=agent_create.project_id,
|
||||
template_id=agent_create.template_id,
|
||||
base_template_id=agent_create.base_template_id,
|
||||
identifier_key=agent_create.identifier_key,
|
||||
message_buffer_autoclear=agent_create.message_buffer_autoclear,
|
||||
)
|
||||
|
||||
@@ -188,6 +192,7 @@ class AgentManager:
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
base_template_id: Optional[str] = None,
|
||||
identifier_key: Optional[str] = None,
|
||||
message_buffer_autoclear: bool = False,
|
||||
) -> PydanticAgentState:
|
||||
"""Create a new agent."""
|
||||
@@ -215,6 +220,10 @@ class AgentManager:
|
||||
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
|
||||
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
|
||||
_process_tags(new_agent, tags, replace=True)
|
||||
if identifier_key is not None:
|
||||
identity = self.identity_manager.get_identity_from_identifier_key(identifier_key)
|
||||
_process_identity(new_agent, identifier_key, identity)
|
||||
|
||||
new_agent.create(session, actor=actor)
|
||||
|
||||
# Convert to PydanticAgentState and return
|
||||
@@ -287,6 +296,9 @@ class AgentManager:
|
||||
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
|
||||
if agent_update.tags is not None:
|
||||
_process_tags(agent, agent_update.tags, replace=True)
|
||||
if agent_update.identifier_key is not None:
|
||||
identity = self.identity_manager.get_identity_from_identifier_key(agent_update.identifier_key)
|
||||
_process_identity(agent, agent_update.identifier_key, identity)
|
||||
|
||||
# Commit and refresh the agent
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.identity import Identity
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
@@ -84,6 +85,20 @@ def _process_tags(agent: AgentModel, tags: List[str], replace=True):
|
||||
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
|
||||
|
||||
|
||||
def _process_identity(agent: AgentModel, identifier_key: str, identity: Identity):
|
||||
"""
|
||||
Handles identity for an agent.
|
||||
|
||||
Args:
|
||||
agent: The AgentModel instance.
|
||||
identifier_key: The identifier key of the identity to set or update.
|
||||
identity: The Identity object to set or update.
|
||||
"""
|
||||
agent.identifier_key = identifier_key
|
||||
agent.identity = identity
|
||||
agent.identity_id = identity.id
|
||||
|
||||
|
||||
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
if system is None:
|
||||
# TODO: don't hardcode
|
||||
|
||||
17
letta/services/identity_manager.py
Normal file
17
letta/services/identity_manager.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from letta.orm.identity import Identity as IdentityModel
|
||||
from letta.schemas.identity import Identity as PydanticIdentity
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class IdentityManager:
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def get_identity_from_identifier_key(self, identifier_key: str) -> PydanticIdentity:
|
||||
with self.session_maker() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier_key=identifier_key)
|
||||
return identity.to_pydantic()
|
||||
Reference in New Issue
Block a user