diff --git a/alembic/versions/af842aa6f743_add_tool_indexes_for_organization_id.py b/alembic/versions/af842aa6f743_add_tool_indexes_for_organization_id.py
new file mode 100644
index 00000000..bf44b47c
--- /dev/null
+++ b/alembic/versions/af842aa6f743_add_tool_indexes_for_organization_id.py
@@ -0,0 +1,37 @@
+"""add tool indexes for organization_id
+
+Revision ID: af842aa6f743
+Revises: 175dd10fb916
+Create Date: 2025-12-07 15:30:43.407495
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "af842aa6f743"
+down_revision: Union[str, None] = "175dd10fb916"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index(op.f("idx_messages_step_id"), table_name="messages")  # NOTE(review): unrelated to tool indexes -- autogen picked up other model drift; confirm this drop is intended
+    op.drop_index(op.f("ix_step_metrics_run_id"), table_name="step_metrics")  # NOTE(review): unrelated to tool indexes -- confirm this drop is intended
+    op.create_index("ix_tools_organization_id", "tools", ["organization_id"], unique=False)
+    op.create_index("ix_tools_organization_id_name", "tools", ["organization_id", "name"], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index("ix_tools_organization_id_name", table_name="tools")
+    op.drop_index("ix_tools_organization_id", table_name="tools")
+    op.create_index(op.f("ix_step_metrics_run_id"), "step_metrics", ["run_id"], unique=False)
+    op.create_index(op.f("idx_messages_step_id"), "messages", ["step_id"], unique=False)
+    # ### end Alembic commands ###
diff --git a/letta/orm/tool.py b/letta/orm/tool.py
index c6c8d823..10274bba 100644
--- a/letta/orm/tool.py
+++ b/letta/orm/tool.py
@@ -30,6 +30,8 @@ class Tool(SqlalchemyBase, OrganizationMixin):
     __table_args__ = (
         UniqueConstraint("name", "organization_id", name="uix_name_organization"),
         Index("ix_tools_created_at_name", "created_at", "name"),
+        Index("ix_tools_organization_id", "organization_id"),
+        Index("ix_tools_organization_id_name", "organization_id", "name"),
     )
 
     name: Mapped[str] = mapped_column(doc="The display name of the tool.")
diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py
index 0e1a2353..2a755f02 100644
--- a/letta/services/tool_manager.py
+++ b/letta/services/tool_manager.py
@@ -200,7 +200,11 @@ class ToolManager:
     async def create_or_update_tool_async(
         self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False, modal_sandbox_enabled: bool = False
     ) -> PydanticTool:
-        """Create a new tool based on the ToolCreate schema."""
+        """Create a new tool based on the ToolCreate schema.
+
+        Uses atomic PostgreSQL ON CONFLICT DO UPDATE to prevent race conditions
+        during concurrent upserts.
+ """ from letta.otel.tracing import tracer if pydantic_tool.tool_type == ToolType.CUSTOM and not pydantic_tool.json_schema: @@ -228,7 +232,11 @@ class ToolManager: source_code=pydantic_tool.source_code, ) - # check if the tool name already exists + # Use atomic PostgreSQL upsert if available + if settings.letta_pg_uri_no_default: + return await self._atomic_upsert_tool_postgresql(pydantic_tool, actor, modal_sandbox_enabled) + + # Fallback for SQLite: use non-atomic check-then-act pattern current_tool = await self.get_tool_by_name_async(tool_name=pydantic_tool.name, actor=actor) if current_tool: # Put to dict and remove fields that should not be reset @@ -261,6 +269,95 @@ class ToolManager: return await self.create_tool_async(pydantic_tool, actor=actor, modal_sandbox_enabled=modal_sandbox_enabled) + @enforce_types + @trace_method + async def _atomic_upsert_tool_postgresql( + self, pydantic_tool: PydanticTool, actor: PydanticUser, modal_sandbox_enabled: bool = False + ) -> PydanticTool: + """Atomically upsert a single tool using PostgreSQL's ON CONFLICT DO UPDATE. + + This prevents race conditions when multiple concurrent requests try to + create/update the same tool by name. 
+ """ + from sqlalchemy.dialects.postgresql import insert as pg_insert + + # Auto-generate description if not provided + if pydantic_tool.description is None and pydantic_tool.json_schema: + pydantic_tool.description = pydantic_tool.json_schema.get("description", None) + + # Add sandbox:modal to metadata if flag is enabled + if modal_sandbox_enabled: + if pydantic_tool.metadata_ is None: + pydantic_tool.metadata_ = {} + pydantic_tool.metadata_["sandbox"] = "modal" + + # Add tool hash to metadata for Modal deployment tracking + tool_hash = compute_tool_hash(pydantic_tool) + if pydantic_tool.metadata_ is None: + pydantic_tool.metadata_ = {} + pydantic_tool.metadata_["tool_hash"] = tool_hash + + async with db_registry.async_session() as session: + table = ToolModel.__table__ + valid_columns = {col.name for col in table.columns} + + tool_dict = pydantic_tool.model_dump(to_orm=True) + tool_dict["_created_by_id"] = actor.id + tool_dict["_last_updated_by_id"] = actor.id + tool_dict["organization_id"] = actor.organization_id + + # Filter to only include columns that exist in the table + # Also exclude None values to let database defaults apply + insert_data = {k: v for k, v in tool_dict.items() if k in valid_columns and v is not None} + + # Build the INSERT ... 
ON CONFLICT DO UPDATE statement + stmt = pg_insert(table).values(**insert_data) + + # On conflict, update all columns except id, created_at, and _created_by_id + excluded = stmt.excluded + update_dict = {} + for col in table.columns: + if col.name not in ("id", "created_at", "_created_by_id"): + if col.name == "updated_at": + update_dict[col.name] = func.now() + elif col.name == "tags" and (insert_data["tags"] is None or len(insert_data["tags"]) == 0): + # TODO: intentional bug to avoid overriding with empty tags on every upsert + # means you cannot clear tags, only override them + if insert_data["tags"] is None or len(insert_data["tags"]) == 0: + continue + update_dict[col.name] = excluded[col.name] + else: + update_dict[col.name] = excluded[col.name] + + upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict).returning(table.c.id) + + result = await session.execute(upsert_stmt) + tool_id = result.scalar_one() + await session.commit() + + # Fetch the upserted tool + tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) + upserted_tool = tool.to_pydantic() + + # Deploy Modal app if needed (both Modal credentials configured AND tool metadata must indicate Modal) + # TODO: dont have such duplicated code + tool_requests_modal = upserted_tool.metadata_ and upserted_tool.metadata_.get("sandbox") == "modal" + modal_configured = tool_settings.modal_sandbox_enabled + + if upserted_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured: + await self.create_or_update_modal_app(upserted_tool, actor) + + # Embed tool in Turbopuffer if enabled + from letta.helpers.tpuf_client import should_use_tpuf_for_tools + + if should_use_tpuf_for_tools(): + fire_and_forget( + self._embed_tool_background(upserted_tool, actor), + task_name=f"embed_tool_{upserted_tool.id}", + ) + + return upserted_tool + @enforce_types async def create_mcp_server( self, server_config: 
Union[StdioServerConfig, SSEServerConfig], actor: PydanticUser
@@ -336,24 +433,25 @@ class ToolManager:
             await tool.create_async(session, actor=actor)
             # Re-raise other database-related errors
         created_tool = tool.to_pydantic()
-        # Deploy Modal app for the new tool
-        # Both Modal credentials configured AND tool metadata must indicate Modal
-        tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
-        modal_configured = tool_settings.modal_sandbox_enabled
+        # TODO: don't have such duplicated code
+        # Deploy Modal app for the new tool
+        # Both Modal credentials configured AND tool metadata must indicate Modal
+        tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
+        modal_configured = tool_settings.modal_sandbox_enabled
 
-        if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
-            await self.create_or_update_modal_app(created_tool, actor)
+        if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
+            await self.create_or_update_modal_app(created_tool, actor)
 
-        # Embed tool in Turbopuffer if enabled
-        from letta.helpers.tpuf_client import should_use_tpuf_for_tools
+        # Embed tool in Turbopuffer if enabled
+        from letta.helpers.tpuf_client import should_use_tpuf_for_tools
 
-        if should_use_tpuf_for_tools():
-            fire_and_forget(
-                self._embed_tool_background(created_tool, actor),
-                task_name=f"embed_tool_{created_tool.id}",
-            )
+        if should_use_tpuf_for_tools():
+            fire_and_forget(
+                self._embed_tool_background(created_tool, actor),
+                task_name=f"embed_tool_{created_tool.id}",
+            )
 
-        return created_tool
+        return created_tool
 
     @enforce_types
     @trace_method
@@ -470,6 +568,37 @@ class ToolManager:
         count = result.scalar()
         return count > 0
 
+    @enforce_types
+    async def _check_tool_name_conflict_with_lock_async(self, session, tool_name: str, exclude_tool_id: str, actor: PydanticUser) -> bool:
+        """Check if a tool with the given name exists (excluding the current tool), with row locking.
+
+        Uses SELECT FOR UPDATE to prevent race conditions when two concurrent updates
+        try to rename tools to the same name.
+
+        Args:
+            session: The database session (must be part of an active transaction)
+            tool_name: The name to check for conflicts
+            exclude_tool_id: The ID of the current tool being updated (to exclude from check)
+            actor: The user performing the action
+
+        Returns:
+            True if a conflicting tool exists, False otherwise
+        """
+        # Use SELECT FOR UPDATE to lock any existing row with this name
+        # This prevents another concurrent transaction from also checking and then updating (NOTE(review): FOR UPDATE cannot lock a not-yet-existing row, so two concurrent renames to a brand-new name still race down to the uix_name_organization unique constraint -- confirm the resulting IntegrityError is handled)
+        query = (
+            select(ToolModel.id)
+            .where(
+                ToolModel.name == tool_name,
+                ToolModel.organization_id == actor.organization_id,
+                ToolModel.id != exclude_tool_id,
+            )
+            .with_for_update(nowait=False)  # Wait for lock if another transaction holds it
+        )
+        result = await session.execute(query)
+        existing_tool = result.scalar()
+        return existing_tool is not None
+
     @enforce_types
     @trace_method
     async def list_tools_async(
@@ -786,11 +915,8 @@ class ToolManager:
            #     f"JSON schema name '{new_name}' conflicts with current tool name '{current_tool.name}'. Update the name field explicitly if you want to rename the tool."
# )
 
-        # If name changes, enforce uniqueness
-        if new_name != current_tool.name:
-            name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
-            if name_exists:
-                raise LettaToolNameConflictError(tool_name=new_name)
+        # Track if we need to check name uniqueness (check is done inside session with lock)
+        needs_name_conflict_check = new_name != current_tool.name
 
         # NOTE: EXTREMELEY HACKY, we need to stop making assumptions about the source_code
         if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""):
@@ -849,6 +975,18 @@ class ToolManager:
 
         # Now perform the update within the session
         async with db_registry.async_session() as session:
+            # Check name uniqueness with lock INSIDE the session to prevent race conditions
+            # This uses SELECT FOR UPDATE to ensure no other transaction can rename to this name (NOTE(review): only guards against rows that already exist -- inserts of the same new name are serialized by the unique constraint instead)
+            if needs_name_conflict_check:
+                name_conflict = await self._check_tool_name_conflict_with_lock_async(
+                    session=session,
+                    tool_name=new_name,
+                    exclude_tool_id=tool_id,
+                    actor=actor,
+                )
+                if name_conflict:
+                    raise LettaToolNameConflictError(tool_name=new_name)
+
             # Fetch the tool by ID
             tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)