diff --git a/alembic/versions/b3c920939d81_unique_project_id_for_tools.py b/alembic/versions/b3c920939d81_unique_project_id_for_tools.py new file mode 100644 index 00000000..8bbab921 --- /dev/null +++ b/alembic/versions/b3c920939d81_unique_project_id_for_tools.py @@ -0,0 +1,35 @@ +"""unique project_id for tools + +Revision ID: b3c920939d81 +Revises: d0880aae6cee +Create Date: 2025-12-10 18:22:57.294283 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b3c920939d81" +down_revision: Union[str, None] = "d0880aae6cee" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f("uix_name_organization"), "tools", type_="unique") + op.create_unique_constraint( + "uix_organization_project_name", "tools", ["organization_id", "project_id", "name"], postgresql_nulls_not_distinct=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("uix_organization_project_name", "tools", type_="unique") + op.create_unique_constraint(op.f("uix_name_organization"), "tools", ["name", "organization_id"], postgresql_nulls_not_distinct=False) + # ### end Alembic commands ### diff --git a/letta/orm/tool.py b/letta/orm/tool.py index 2034cc35..f0dd0385 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -28,7 +28,7 @@ class Tool(SqlalchemyBase, OrganizationMixin, ProjectMixin): # Add unique constraint on (name, _organization_id) # An organization should not have multiple tools with the same name __table_args__ = ( - UniqueConstraint("name", "organization_id", name="uix_name_organization"), + UniqueConstraint("organization_id", "project_id", "name", name="uix_organization_project_name", postgresql_nulls_not_distinct=True), Index("ix_tools_created_at_name", "created_at", "name"), Index("ix_tools_organization_id", "organization_id"), Index("ix_tools_organization_id_name", "organization_id", "name"), diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index de3ff428..cc5feee6 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -297,6 +297,16 @@ class ToolManager: pydantic_tool.metadata_ = {} pydantic_tool.metadata_["tool_hash"] = tool_hash + # Check for tool name conflicts across projects + # This prevents name collisions between global tools and project-scoped tools + has_conflict = await self._check_tool_name_conflict_across_projects_async( + tool_name=pydantic_tool.name, + project_id=pydantic_tool.project_id, + actor=actor, + ) + if has_conflict: + raise LettaToolNameConflictError(tool_name=pydantic_tool.name) + async with db_registry.async_session() as session: table = ToolModel.__table__ valid_columns = {col.name for col in table.columns} @@ -329,7 +339,9 @@ class ToolManager: else: update_dict[col.name] = excluded[col.name] - upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict).returning(table.c.id) + upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id", "project_id"], set_=update_dict).returning( + table.c.id + ) result = await session.execute(upsert_stmt) tool_id = result.scalar_one() @@ -395,8 +407,17 @@ class ToolManager: self, pydantic_tool: PydanticTool, actor: PydanticUser, modal_sandbox_enabled: bool = False ) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" - # Generate schema only if not provided (only for custom tools) + # Check for tool name conflicts across projects + # This prevents name collisions between global tools and project-scoped tools + has_conflict = await self._check_tool_name_conflict_across_projects_async( + tool_name=pydantic_tool.name, + project_id=pydantic_tool.project_id, + actor=actor, + ) + if has_conflict: + raise LettaToolNameConflictError(tool_name=pydantic_tool.name) + # Generate schema only if not provided (only for custom tools) async with db_registry.async_session() as session: # Auto-generate description if not provided if pydantic_tool.description is None and pydantic_tool.json_schema: @@ -599,6 +620,61 @@ class ToolManager: existing_tool = result.scalar() return existing_tool is not None + @enforce_types + async def _check_tool_name_conflict_across_projects_async( + self, + tool_name: str, + project_id: Optional[str], + actor: PydanticUser, + exclude_tool_id: Optional[str] = None, + ) -> bool: + """Check for tool name conflicts across projects within an organization. + + This handles the uniqueness constraint where: + - If project_id is supplied: ensure no global tools (project_id=None) exist with same name + - If project_id is not supplied (global tool): ensure no tools in any project exist with same name + + Uses SELECT FOR UPDATE to prevent race conditions. + + Args: + session: The database session (must be part of an active transaction) + tool_name: The name to check for conflicts + project_id: The project_id of the tool being created/updated (None for global tools) + actor: The user performing the action + exclude_tool_id: Optional ID of tool to exclude from check (for updates) + + Returns: + True if a conflicting tool exists, False otherwise + """ + # Build query based on project_id + if project_id is not None: + # Creating/updating a project-scoped tool + # Check if there's a global tool (project_id=None) with the same name + query = select(ToolModel.id).where( + ToolModel.name == tool_name, + ToolModel.organization_id == actor.organization_id, + ToolModel.project_id.is_(None), # Only check global tools + ) + else: + # Creating/updating a global tool (project_id=None) + # Check if there's any tool in any project with the same name + query = select(ToolModel.id).where( + ToolModel.name == tool_name, + ToolModel.organization_id == actor.organization_id, + ToolModel.project_id.isnot(None), # Only check project-scoped tools + ) + + # Exclude current tool if updating + if exclude_tool_id is not None: + query = query.where(ToolModel.id != exclude_tool_id) + + # Use FOR UPDATE to prevent race conditions + query = query.with_for_update(nowait=False) + async with db_registry.async_session() as session: + result = await session.execute(query) + existing_tool = result.scalar() + return existing_tool is not None + @enforce_types @trace_method async def list_tools_async( @@ -1000,6 +1076,17 @@ class ToolManager: if name_conflict: raise LettaToolNameConflictError(tool_name=new_name) + # Check for tool name conflicts across projects (global vs project-scoped) + # This prevents name collisions between global tools and project-scoped tools + cross_project_conflict = await self._check_tool_name_conflict_across_projects_async( + tool_name=new_name, + project_id=current_tool.project_id, + actor=actor, + exclude_tool_id=tool_id, + ) + if cross_project_conflict: + raise LettaToolNameConflictError(tool_name=new_name) + # Fetch the tool by ID tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) @@ -1220,10 +1307,10 @@ class ToolManager: else: update_dict[col.name] = excluded[col.name] - upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict) + upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id", "project_id"], set_=update_dict) else: # on conflict, do nothing (skip existing tools) - upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"]) + upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id", "project_id"]) await session.execute(upsert_stmt) await session.commit()