feat: project_id uniqueness for tools (#6604)

* feat: project_id uniqueness for tools

* prevent double upsert of global tools

* use default project if no header for sdk

* reorder unique constraint for performance

* use separate session for check conflict

* feature flag adding project id header in cloud api

* add my migration after one on main

* remove comment

* stage and publish api

* web set project id just for tools

* includes instead of startswith
This commit is contained in:
Ari Webb
2025-12-12 13:01:36 -08:00
committed by Caren Thomas
parent 22b9ed254a
commit c1aa01db6f
3 changed files with 127 additions and 5 deletions

View File

@@ -0,0 +1,35 @@
"""unique project_id for tools
Revision ID: b3c920939d81
Revises: d0880aae6cee
Create Date: 2025-12-10 18:22:57.294283
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "b3c920939d81"
down_revision: Union[str, None] = "d0880aae6cee"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(op.f("uix_name_organization"), "tools", type_="unique")
op.create_unique_constraint(
"uix_organization_project_name", "tools", ["organization_id", "project_id", "name"], postgresql_nulls_not_distinct=True
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("uix_organization_project_name", "tools", type_="unique")
op.create_unique_constraint(op.f("uix_name_organization"), "tools", ["name", "organization_id"], postgresql_nulls_not_distinct=False)
# ### end Alembic commands ###

View File

@@ -28,7 +28,7 @@ class Tool(SqlalchemyBase, OrganizationMixin, ProjectMixin):
# Add unique constraint on (name, _organization_id)
# An organization should not have multiple tools with the same name
__table_args__ = (
UniqueConstraint("name", "organization_id", name="uix_name_organization"),
UniqueConstraint("organization_id", "project_id", "name", name="uix_organization_project_name", postgresql_nulls_not_distinct=True),
Index("ix_tools_created_at_name", "created_at", "name"),
Index("ix_tools_organization_id", "organization_id"),
Index("ix_tools_organization_id_name", "organization_id", "name"),

View File

@@ -297,6 +297,16 @@ class ToolManager:
pydantic_tool.metadata_ = {}
pydantic_tool.metadata_["tool_hash"] = tool_hash
# Check for tool name conflicts across projects
# This prevents name collisions between global tools and project-scoped tools
has_conflict = await self._check_tool_name_conflict_across_projects_async(
tool_name=pydantic_tool.name,
project_id=pydantic_tool.project_id,
actor=actor,
)
if has_conflict:
raise LettaToolNameConflictError(tool_name=pydantic_tool.name)
async with db_registry.async_session() as session:
table = ToolModel.__table__
valid_columns = {col.name for col in table.columns}
@@ -329,7 +339,9 @@ class ToolManager:
else:
update_dict[col.name] = excluded[col.name]
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict).returning(table.c.id)
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id", "project_id"], set_=update_dict).returning(
table.c.id
)
result = await session.execute(upsert_stmt)
tool_id = result.scalar_one()
@@ -395,8 +407,17 @@ class ToolManager:
self, pydantic_tool: PydanticTool, actor: PydanticUser, modal_sandbox_enabled: bool = False
) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Generate schema only if not provided (only for custom tools)
# Check for tool name conflicts across projects
# This prevents name collisions between global tools and project-scoped tools
has_conflict = await self._check_tool_name_conflict_across_projects_async(
tool_name=pydantic_tool.name,
project_id=pydantic_tool.project_id,
actor=actor,
)
if has_conflict:
raise LettaToolNameConflictError(tool_name=pydantic_tool.name)
# Generate schema only if not provided (only for custom tools)
async with db_registry.async_session() as session:
# Auto-generate description if not provided
if pydantic_tool.description is None and pydantic_tool.json_schema:
@@ -599,6 +620,61 @@ class ToolManager:
existing_tool = result.scalar()
return existing_tool is not None
@enforce_types
async def _check_tool_name_conflict_across_projects_async(
self,
tool_name: str,
project_id: Optional[str],
actor: PydanticUser,
exclude_tool_id: Optional[str] = None,
) -> bool:
"""Check for tool name conflicts across projects within an organization.
This handles the uniqueness constraint where:
- If project_id is supplied: ensure no global tools (project_id=None) exist with same name
- If project_id is not supplied (global tool): ensure no tools in any project exist with same name
Uses SELECT FOR UPDATE to prevent race conditions.
Args:
session: The database session (must be part of an active transaction)
tool_name: The name to check for conflicts
project_id: The project_id of the tool being created/updated (None for global tools)
actor: The user performing the action
exclude_tool_id: Optional ID of tool to exclude from check (for updates)
Returns:
True if a conflicting tool exists, False otherwise
"""
# Build query based on project_id
if project_id is not None:
# Creating/updating a project-scoped tool
# Check if there's a global tool (project_id=None) with the same name
query = select(ToolModel.id).where(
ToolModel.name == tool_name,
ToolModel.organization_id == actor.organization_id,
ToolModel.project_id.is_(None), # Only check global tools
)
else:
# Creating/updating a global tool (project_id=None)
# Check if there's any tool in any project with the same name
query = select(ToolModel.id).where(
ToolModel.name == tool_name,
ToolModel.organization_id == actor.organization_id,
ToolModel.project_id.isnot(None), # Only check project-scoped tools
)
# Exclude current tool if updating
if exclude_tool_id is not None:
query = query.where(ToolModel.id != exclude_tool_id)
# Use FOR UPDATE to prevent race conditions
query = query.with_for_update(nowait=False)
async with db_registry.async_session() as session:
result = await session.execute(query)
existing_tool = result.scalar()
return existing_tool is not None
@enforce_types
@trace_method
async def list_tools_async(
@@ -1000,6 +1076,17 @@ class ToolManager:
if name_conflict:
raise LettaToolNameConflictError(tool_name=new_name)
# Check for tool name conflicts across projects (global vs project-scoped)
# This prevents name collisions between global tools and project-scoped tools
cross_project_conflict = await self._check_tool_name_conflict_across_projects_async(
tool_name=new_name,
project_id=current_tool.project_id,
actor=actor,
exclude_tool_id=tool_id,
)
if cross_project_conflict:
raise LettaToolNameConflictError(tool_name=new_name)
# Fetch the tool by ID
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
@@ -1220,10 +1307,10 @@ class ToolManager:
else:
update_dict[col.name] = excluded[col.name]
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id", "project_id"], set_=update_dict)
else:
# on conflict, do nothing (skip existing tools)
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id", "project_id"])
await session.execute(upsert_stmt)
await session.commit()