feat: project_id uniqueness for tools (#6604)
* feat: project_id uniqueness for tools * prevent double upsert of global tools * use default project if no header for sdk * reorder unique constraint for performance * use separate session for check conflict * feature flag adding project id header in cloud api * add my migration after one on main * remove comment * stage and publish api * web set project id just for tools * includes instead of startswith
This commit is contained in:
35
alembic/versions/b3c920939d81_unique_project_id_for_tools.py
Normal file
35
alembic/versions/b3c920939d81_unique_project_id_for_tools.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""unique project_id for tools
|
||||
|
||||
Revision ID: b3c920939d81
|
||||
Revises: d0880aae6cee
|
||||
Create Date: 2025-12-10 18:22:57.294283
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b3c920939d81"
|
||||
down_revision: Union[str, None] = "d0880aae6cee"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(op.f("uix_name_organization"), "tools", type_="unique")
|
||||
op.create_unique_constraint(
|
||||
"uix_organization_project_name", "tools", ["organization_id", "project_id", "name"], postgresql_nulls_not_distinct=True
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("uix_organization_project_name", "tools", type_="unique")
|
||||
op.create_unique_constraint(op.f("uix_name_organization"), "tools", ["name", "organization_id"], postgresql_nulls_not_distinct=False)
|
||||
# ### end Alembic commands ###
|
||||
@@ -28,7 +28,7 @@ class Tool(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
||||
# Add unique constraint on (name, _organization_id)
|
||||
# An organization should not have multiple tools with the same name
|
||||
__table_args__ = (
|
||||
UniqueConstraint("name", "organization_id", name="uix_name_organization"),
|
||||
UniqueConstraint("organization_id", "project_id", "name", name="uix_organization_project_name", postgresql_nulls_not_distinct=True),
|
||||
Index("ix_tools_created_at_name", "created_at", "name"),
|
||||
Index("ix_tools_organization_id", "organization_id"),
|
||||
Index("ix_tools_organization_id_name", "organization_id", "name"),
|
||||
|
||||
@@ -297,6 +297,16 @@ class ToolManager:
|
||||
pydantic_tool.metadata_ = {}
|
||||
pydantic_tool.metadata_["tool_hash"] = tool_hash
|
||||
|
||||
# Check for tool name conflicts across projects
|
||||
# This prevents name collisions between global tools and project-scoped tools
|
||||
has_conflict = await self._check_tool_name_conflict_across_projects_async(
|
||||
tool_name=pydantic_tool.name,
|
||||
project_id=pydantic_tool.project_id,
|
||||
actor=actor,
|
||||
)
|
||||
if has_conflict:
|
||||
raise LettaToolNameConflictError(tool_name=pydantic_tool.name)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
table = ToolModel.__table__
|
||||
valid_columns = {col.name for col in table.columns}
|
||||
@@ -329,7 +339,9 @@ class ToolManager:
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict).returning(table.c.id)
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id", "project_id"], set_=update_dict).returning(
|
||||
table.c.id
|
||||
)
|
||||
|
||||
result = await session.execute(upsert_stmt)
|
||||
tool_id = result.scalar_one()
|
||||
@@ -395,8 +407,17 @@ class ToolManager:
|
||||
self, pydantic_tool: PydanticTool, actor: PydanticUser, modal_sandbox_enabled: bool = False
|
||||
) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Generate schema only if not provided (only for custom tools)
|
||||
# Check for tool name conflicts across projects
|
||||
# This prevents name collisions between global tools and project-scoped tools
|
||||
has_conflict = await self._check_tool_name_conflict_across_projects_async(
|
||||
tool_name=pydantic_tool.name,
|
||||
project_id=pydantic_tool.project_id,
|
||||
actor=actor,
|
||||
)
|
||||
if has_conflict:
|
||||
raise LettaToolNameConflictError(tool_name=pydantic_tool.name)
|
||||
|
||||
# Generate schema only if not provided (only for custom tools)
|
||||
async with db_registry.async_session() as session:
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None and pydantic_tool.json_schema:
|
||||
@@ -599,6 +620,61 @@ class ToolManager:
|
||||
existing_tool = result.scalar()
|
||||
return existing_tool is not None
|
||||
|
||||
@enforce_types
|
||||
async def _check_tool_name_conflict_across_projects_async(
|
||||
self,
|
||||
tool_name: str,
|
||||
project_id: Optional[str],
|
||||
actor: PydanticUser,
|
||||
exclude_tool_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Check for tool name conflicts across projects within an organization.
|
||||
|
||||
This handles the uniqueness constraint where:
|
||||
- If project_id is supplied: ensure no global tools (project_id=None) exist with same name
|
||||
- If project_id is not supplied (global tool): ensure no tools in any project exist with same name
|
||||
|
||||
Uses SELECT FOR UPDATE to prevent race conditions.
|
||||
|
||||
Args:
|
||||
session: The database session (must be part of an active transaction)
|
||||
tool_name: The name to check for conflicts
|
||||
project_id: The project_id of the tool being created/updated (None for global tools)
|
||||
actor: The user performing the action
|
||||
exclude_tool_id: Optional ID of tool to exclude from check (for updates)
|
||||
|
||||
Returns:
|
||||
True if a conflicting tool exists, False otherwise
|
||||
"""
|
||||
# Build query based on project_id
|
||||
if project_id is not None:
|
||||
# Creating/updating a project-scoped tool
|
||||
# Check if there's a global tool (project_id=None) with the same name
|
||||
query = select(ToolModel.id).where(
|
||||
ToolModel.name == tool_name,
|
||||
ToolModel.organization_id == actor.organization_id,
|
||||
ToolModel.project_id.is_(None), # Only check global tools
|
||||
)
|
||||
else:
|
||||
# Creating/updating a global tool (project_id=None)
|
||||
# Check if there's any tool in any project with the same name
|
||||
query = select(ToolModel.id).where(
|
||||
ToolModel.name == tool_name,
|
||||
ToolModel.organization_id == actor.organization_id,
|
||||
ToolModel.project_id.isnot(None), # Only check project-scoped tools
|
||||
)
|
||||
|
||||
# Exclude current tool if updating
|
||||
if exclude_tool_id is not None:
|
||||
query = query.where(ToolModel.id != exclude_tool_id)
|
||||
|
||||
# Use FOR UPDATE to prevent race conditions
|
||||
query = query.with_for_update(nowait=False)
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(query)
|
||||
existing_tool = result.scalar()
|
||||
return existing_tool is not None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tools_async(
|
||||
@@ -1000,6 +1076,17 @@ class ToolManager:
|
||||
if name_conflict:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
# Check for tool name conflicts across projects (global vs project-scoped)
|
||||
# This prevents name collisions between global tools and project-scoped tools
|
||||
cross_project_conflict = await self._check_tool_name_conflict_across_projects_async(
|
||||
tool_name=new_name,
|
||||
project_id=current_tool.project_id,
|
||||
actor=actor,
|
||||
exclude_tool_id=tool_id,
|
||||
)
|
||||
if cross_project_conflict:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
# Fetch the tool by ID
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
|
||||
@@ -1220,10 +1307,10 @@ class ToolManager:
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id", "project_id"], set_=update_dict)
|
||||
else:
|
||||
# on conflict, do nothing (skip existing tools)
|
||||
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
|
||||
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id", "project_id"])
|
||||
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
|
||||
Reference in New Issue
Block a user