diff --git a/alembic/versions/d5103ee17ed5_add_template_fields_to_blocks_agents_.py b/alembic/versions/d5103ee17ed5_add_template_fields_to_blocks_agents_.py new file mode 100644 index 00000000..3904d739 --- /dev/null +++ b/alembic/versions/d5103ee17ed5_add_template_fields_to_blocks_agents_.py @@ -0,0 +1,47 @@ +"""add template fields to blocks agents groups + +Revision ID: d5103ee17ed5 +Revises: ffb17eb241fc +Create Date: 2025-08-26 15:45:32.949892 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d5103ee17ed5" +down_revision: Union[str, None] = "ffb17eb241fc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("agents", sa.Column("entity_id", sa.String(), nullable=True)) + op.add_column("agents", sa.Column("deployment_id", sa.String(), nullable=True)) + op.add_column("block", sa.Column("entity_id", sa.String(), nullable=True)) + op.add_column("block", sa.Column("base_template_id", sa.String(), nullable=True)) + op.add_column("block", sa.Column("template_id", sa.String(), nullable=True)) + op.add_column("block", sa.Column("deployment_id", sa.String(), nullable=True)) + op.add_column("groups", sa.Column("base_template_id", sa.String(), nullable=True)) + op.add_column("groups", sa.Column("template_id", sa.String(), nullable=True)) + op.add_column("groups", sa.Column("deployment_id", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("groups", "deployment_id") + op.drop_column("groups", "template_id") + op.drop_column("groups", "base_template_id") + op.drop_column("block", "deployment_id") + op.drop_column("block", "template_id") + op.drop_column("block", "base_template_id") + op.drop_column("block", "entity_id") + op.drop_column("agents", "deployment_id") + op.drop_column("agents", "entity_id") + # ### end Alembic commands ### diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 81d5efbe..e90ec5b8 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn from letta.orm.identity import Identity -from letta.orm.mixins import OrganizationMixin, ProjectMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin from letta.orm.organization import Organization from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.agent import AgentState as PydanticAgentState @@ -32,7 +32,7 @@ if TYPE_CHECKING: from letta.orm.tool import Tool -class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, AsyncAttrs): +class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin, AsyncAttrs): __tablename__ = "agents" __pydantic_model__ = PydanticAgentState __table_args__ = (Index("ix_agents_created_at", "created_at", "id"),) @@ -68,8 +68,6 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, AsyncAttrs): embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column( EmbeddingConfigColumn, doc="the embedding configuration object for this agent." ) - template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.") - base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.") # Tool rules tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.") @@ -208,6 +206,8 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, AsyncAttrs): "project_id": self.project_id, "template_id": self.template_id, "base_template_id": self.base_template_id, + "deployment_id": self.deployment_id, + "entity_id": self.entity_id, "tool_rules": self.tool_rules, "message_buffer_autoclear": self.message_buffer_autoclear, "created_by_id": self.created_by_id, @@ -296,6 +296,8 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, AsyncAttrs): "project_id": self.project_id, "template_id": self.template_id, "base_template_id": self.base_template_id, + "deployment_id": self.deployment_id, + "entity_id": self.entity_id, "tool_rules": self.tool_rules, "message_buffer_autoclear": self.message_buffer_autoclear, "created_by_id": self.created_by_id, diff --git a/letta/orm/block.py b/letta/orm/block.py index a31b5a87..d17e89c5 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, attributes, declared_attr, mapped_column, rel from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents -from letta.orm.mixins import OrganizationMixin, ProjectMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import Human, Persona @@ -16,7 +16,7 @@ if TYPE_CHECKING: from letta.orm.identity import Identity -class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin): +class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin, TemplateMixin): """Blocks are sections of the LLM context, representing a specific part of the total Memory""" __tablename__ = "block" diff --git a/letta/orm/group.py b/letta/orm/group.py index fe2cfe7e..e819ec12 100644 --- a/letta/orm/group.py +++ b/letta/orm/group.py @@ -4,12 +4,12 @@ from typing import List, Optional from sqlalchemy import JSON, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.mixins import OrganizationMixin, ProjectMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.group import Group as PydanticGroup -class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin): +class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin): __tablename__ = "groups" __pydantic_model__ = PydanticGroup diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 13848f17..9358e51c 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -78,3 +78,21 @@ class ArchiveMixin(Base): __abstract__ = True archive_id: Mapped[str] = mapped_column(String, ForeignKey("archives.id", ondelete="CASCADE")) + + +class TemplateMixin(Base): + """TemplateMixin for models that belong to a template.""" + + __abstract__ = True + + base_template_id: Mapped[str] = mapped_column(nullable=True, doc="The id of the base template.") + template_id: Mapped[str] = mapped_column(nullable=True, doc="The id of the template.") + deployment_id: Mapped[str] = mapped_column(nullable=True, doc="The id of the deployment.") + + +class TemplateEntityMixin(Base): + """Mixin for models that belong to an entity (only used for templates).""" + + __abstract__ = True + + entity_id: Mapped[str] = mapped_column(nullable=True, doc="The id of the entity within the template.") diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 1b3349fc..cd00f54b 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -91,6 +91,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True): project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") + deployment_id: Optional[str] = Field(None, description="The id of the deployment.") + entity_id: Optional[str] = Field(None, description="The id of the entity within the template.") identity_ids: List[str] = Field([], description="The ids of the identities associated with this agent.") # An advanced configuration that makes it so this agent does not remember any previous messages @@ -304,6 +306,15 @@ class CreateAgent(BaseModel, validate_assignment=True): # return self +class InternalTemplateAgentCreate(CreateAgent): + """Used for Letta Cloud""" + + base_template_id: str = Field(..., description="The id of the base template.") + template_id: str = Field(..., description="The id of the template.") + deployment_id: str = Field(..., description="The id of the deployment.") + entity_id: str = Field(..., description="The id of the entity within the template.") + + class UpdateAgent(BaseModel): name: Optional[str] = Field(None, description="The name of the agent.") tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.") diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 5fc2f4cc..10864954 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -23,6 +23,10 @@ class BaseBlock(LettaBase, validate_assignment=True): # template data (optional) template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name") is_template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).") + template_id: Optional[str] = Field(None, description="The id of the template.", alias="name") + base_template_id: Optional[str] = Field(None, description="The base template id of the block.") + deployment_id: Optional[str] = Field(None, description="The id of the deployment.") + entity_id: Optional[str] = Field(None, description="The id of the entity within the template.") preserve_on_migration: Optional[bool] = Field(False, description="Preserve the block on template migration.") # context window label @@ -168,3 +172,12 @@ class CreatePersonaBlockTemplate(CreatePersona): is_template: bool = True label: str = "persona" + + +class InternalTemplateBlockCreate(CreateBlock): + """Used for Letta Cloud""" + + base_template_id: str = Field(..., description="The id of the base template.") + template_id: str = Field(..., description="The id of the template.") + deployment_id: str = Field(..., description="The id of the deployment.") + entity_id: str = Field(..., description="The id of the entity within the template.") diff --git a/letta/schemas/group.py b/letta/schemas/group.py index eb6c6fd8..8cca0948 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -29,6 +29,10 @@ class Group(GroupBase): agent_ids: List[str] = Field(..., description="") description: str = Field(..., description="") project_id: Optional[str] = Field(None, description="The associated project id.") + # Template fields + template_id: Optional[str] = Field(None, description="The id of the template.") + base_template_id: Optional[str] = Field(None, description="The base template id.") + deployment_id: Optional[str] = Field(None, description="The id of the deployment.") shared_block_ids: List[str] = Field([], description="") # Pattern fields manager_agent_id: Optional[str] = Field(None, description="") @@ -168,6 +172,14 @@ class GroupCreate(BaseModel): shared_block_ids: List[str] = Field([], description="") +class InternalTemplateGroupCreate(GroupCreate): + """Used for Letta Cloud""" + + base_template_id: str = Field(..., description="The id of the base template.") + template_id: str = Field(..., description="The id of the template.") + deployment_id: str = Field(..., description="The id of the deployment.") + + class GroupUpdate(BaseModel): agent_ids: Optional[List[str]] = Field(None, description="") description: Optional[str] = Field(None, description="") diff --git a/letta/server/rest_api/routers/v1/internal_templates.py b/letta/server/rest_api/routers/v1/internal_templates.py new file mode 100644 index 00000000..795f6a42 --- /dev/null +++ b/letta/server/rest_api/routers/v1/internal_templates.py @@ -0,0 +1,68 @@ +from typing import Optional + +from fastapi import APIRouter, Body, Depends, Header, HTTPException + +from letta.schemas.agent import AgentState, InternalTemplateAgentCreate +from letta.schemas.block import Block, InternalTemplateBlockCreate +from letta.schemas.group import Group, InternalTemplateGroupCreate +from letta.server.rest_api.utils import get_letta_server +from letta.server.server import SyncServer + +router = APIRouter(prefix="/_internal_templates", tags=["_internal_templates"]) + + +@router.post("/groups", response_model=Group, operation_id="create_internal_template_group") +async def create_group( + group: InternalTemplateGroupCreate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), + x_project: Optional[str] = Header( + None, alias="X-Project", description="The project slug to associate with the group (cloud only)." + ), # Only handled by next js middleware +): + """ + Create a new multi-agent group with the specified configuration. + """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.group_manager.create_group_async(group, actor=actor) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/agents", response_model=AgentState, operation_id="create_internal_template_agent") +async def create_agent( + agent: InternalTemplateAgentCreate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), + x_project: Optional[str] = Header( + None, alias="X-Project", description="The project slug to associate with the agent (cloud only)." + ), # Only handled by next js middleware +): + """ + Create a new agent with template-related fields. + """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.create_agent_async(agent, actor=actor) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/blocks", response_model=Block, operation_id="create_internal_template_block") +async def create_block( + block: InternalTemplateBlockCreate = Body(...), + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), + x_project: Optional[str] = Header( + None, alias="X-Project", description="The project slug to associate with the block (cloud only)." + ), # Only handled by next js middleware +): + """ + Create a new block with template-related fields. + """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.block_manager.create_or_update_block_async(block, actor=actor) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 4d18f7cd..cd4f180f 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -44,7 +44,7 @@ from letta.orm.sqlalchemy_base import AccessType from letta.otel.tracing import trace_method from letta.prompts.prompt_generator import PromptGenerator from letta.schemas.agent import AgentState as PydanticAgentState -from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent, get_prompt_template_for_agent_type +from letta.schemas.agent import AgentType, CreateAgent, InternalTemplateAgentCreate, UpdateAgent, get_prompt_template_for_agent_type from letta.schemas.block import DEFAULT_BLOCKS from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate @@ -402,6 +402,13 @@ class AgentManager: per_file_view_window_char_limit=agent_create.per_file_view_window_char_limit, ) + # Set template fields for InternalTemplateAgentCreate (similar to group creation) + if isinstance(agent_create, InternalTemplateAgentCreate): + new_agent.base_template_id = agent_create.base_template_id + new_agent.template_id = agent_create.template_id + new_agent.deployment_id = agent_create.deployment_id + new_agent.entity_id = agent_create.entity_id + if _test_only_force_id: new_agent.id = _test_only_force_id @@ -611,6 +618,13 @@ class AgentManager: per_file_view_window_char_limit=agent_create.per_file_view_window_char_limit, ) + # Set template fields for InternalTemplateAgentCreate (similar to group creation) + if isinstance(agent_create, InternalTemplateAgentCreate): + new_agent.base_template_id = agent_create.base_template_id + new_agent.template_id = agent_create.template_id + new_agent.deployment_id = agent_create.deployment_id + new_agent.entity_id = agent_create.entity_id + if _test_only_force_id: new_agent.id = _test_only_force_id diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 91e15d50..7dfffe15 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -9,7 +9,7 @@ from letta.orm.group import Group as GroupModel from letta.orm.message import Message as MessageModel from letta.otel.tracing import trace_method from letta.schemas.group import Group as PydanticGroup -from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType +from letta.schemas.group import GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser @@ -60,7 +60,7 @@ class GroupManager: @enforce_types @trace_method - def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup: + def create_group(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup: with db_registry.session() as session: new_group = GroupModel() new_group.organization_id = actor.organization_id @@ -96,6 +96,11 @@ class GroupManager: case _: raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") + if isinstance(group, InternalTemplateGroupCreate): + new_group.base_template_id = group.base_template_id + new_group.template_id = group.template_id + new_group.deployment_id = group.deployment_id + self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False) if group.shared_block_ids: @@ -105,7 +110,7 @@ class GroupManager: return new_group.to_pydantic() @enforce_types - async def create_group_async(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup: + async def create_group_async(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: new_group = GroupModel() new_group.organization_id = actor.organization_id @@ -141,6 +146,11 @@ class GroupManager: case _: raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") + if isinstance(group, InternalTemplateGroupCreate): + new_group.base_template_id = group.base_template_id + new_group.template_id = group.template_id + new_group.deployment_id = group.deployment_id + await self._process_agent_relationship_async(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False) if group.shared_block_ids: diff --git a/tests/test_managers.py b/tests/test_managers.py index da5a79c2..3f64fdf0 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4553,7 +4553,8 @@ def test_create_block(server: SyncServer, default_user): label="human", is_template=True, value="Sample content", - template_name="sample_template", + template_name="sample_template_name", + template_id="sample_template", description="A test block", limit=1000, metadata={"example": "data"}, @@ -4566,6 +4567,7 @@ def test_create_block(server: SyncServer, default_user): assert block.is_template == block_create.is_template assert block.value == block_create.value assert block.template_name == block_create.template_name + assert block.template_id == block_create.template_id assert block.description == block_create.description assert block.limit == block_create.limit assert block.metadata == block_create.metadata @@ -10940,6 +10942,78 @@ FAILED tests/test_managers.py::test_high_concurrency_stress_test - AssertionErro # await server.block_manager.delete_block_async(block.id, actor=default_user) +def test_create_internal_template_objects(server: SyncServer, default_user): + """Test creating agents, groups, and blocks with template-related fields.""" + from letta.schemas.agent import InternalTemplateAgentCreate + from letta.schemas.block import Block, InternalTemplateBlockCreate + from letta.schemas.group import InternalTemplateGroupCreate, RoundRobinManager + + base_template_id = "base_123" + template_id = "template_456" + deployment_id = "deploy_789" + entity_id = "entity_012" + + # Create agent with template fields (use sarah_agent as base, then create new one) + agent = server.agent_manager.create_agent( + InternalTemplateAgentCreate( + name="template-agent", + base_template_id=base_template_id, + template_id=template_id, + deployment_id=deployment_id, + entity_id=entity_id, + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, + ), + actor=default_user, + ) + # Verify agent template fields + assert agent.base_template_id == base_template_id + assert agent.template_id == template_id + assert agent.deployment_id == deployment_id + assert agent.entity_id == entity_id + + # Create block with template fields + block_create = InternalTemplateBlockCreate( + label="template_block", + value="Test block", + base_template_id=base_template_id, + template_id=template_id, + deployment_id=deployment_id, + entity_id=entity_id, + ) + block = server.block_manager.create_or_update_block(Block(**block_create.model_dump()), actor=default_user) + # Verify block template fields + assert block.base_template_id == base_template_id + assert block.template_id == template_id + assert block.deployment_id == deployment_id + assert block.entity_id == entity_id + + # Create group with template fields (no entity_id for groups) + group = server.group_manager.create_group( + InternalTemplateGroupCreate( + agent_ids=[agent.id], + description="Template group", + base_template_id=base_template_id, + template_id=template_id, + deployment_id=deployment_id, + manager_config=RoundRobinManager(), + ), + actor=default_user, + ) + # Verify group template fields and basic functionality + assert group.description == "Template group" + assert agent.id in group.agent_ids + assert group.base_template_id == base_template_id + assert group.template_id == template_id + assert group.deployment_id == deployment_id + + # Clean up + server.group_manager.delete_group(group.id, actor=default_user) + server.block_manager.delete_block(block.id, actor=default_user) + server.agent_manager.delete_agent(agent.id, actor=default_user) + + # TODO: I use this as a way to easily wipe my local db lol sorry # TODO: Leave this in here I constantly wipe my db for testing unless you care about optics @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index 3afd97ff..91fe7330 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11, <3.14" resolution-markers = [ "python_full_version >= '3.13'", @@ -2331,7 +2331,7 @@ wheels = [ [[package]] name = "letta" -version = "0.10.0" +version = "0.11.6" source = { editable = "." } dependencies = [ { name = "aiomultiprocess" },