feat: add index and concurrency control for tools (fixed alembic) (#6552)

This commit is contained in:
Sarah Wooders
2025-12-07 17:08:00 -08:00
committed by Caren Thomas
parent c8c06168e2
commit 4f1fbe45aa
3 changed files with 194 additions and 21 deletions

View File

@@ -0,0 +1,33 @@
"""add tool indexes for organization_id
Revision ID: af842aa6f743
Revises: 175dd10fb916
Create Date: 2025-12-07 15:30:43.407495
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "af842aa6f743"
down_revision: Union[str, None] = "175dd10fb916"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add organization-scoped lookup indexes to the ``tools`` table."""
    # Create both indexes from one spec table; behavior identical to two
    # explicit op.create_index(...) calls.
    for index_name, columns in (
        ("ix_tools_organization_id", ["organization_id"]),
        ("ix_tools_organization_id_name", ["organization_id", "name"]),
    ):
        op.create_index(index_name, "tools", columns, unique=False)
def downgrade() -> None:
    """Remove the indexes added by this revision.

    BUG FIX: the original downgrade created unrelated indexes on
    ``step_metrics`` and ``messages`` (copy-paste from another migration)
    instead of dropping the two ``tools`` indexes created in ``upgrade()``.
    A downgrade must exactly undo its upgrade.
    """
    # Drop in reverse order of creation.
    op.drop_index("ix_tools_organization_id_name", table_name="tools")
    op.drop_index("ix_tools_organization_id", table_name="tools")

View File

@@ -30,6 +30,8 @@ class Tool(SqlalchemyBase, OrganizationMixin):
# Table-level constraints and indexes for the tools table.
__table_args__ = (
# One tool name per organization; this is also the ON CONFLICT target used
# by the PostgreSQL atomic upsert path.
UniqueConstraint("name", "organization_id", name="uix_name_organization"),
Index("ix_tools_created_at_name", "created_at", "name"),
# Org-scoped lookups (added by migration af842aa6f743).
# NOTE(review): this single-column index looks redundant with the leading
# column of ix_tools_organization_id_name below — confirm both are needed.
Index("ix_tools_organization_id", "organization_id"),
Index("ix_tools_organization_id_name", "organization_id", "name"),
)
# Display name of the tool (doc string carried on the column itself).
name: Mapped[str] = mapped_column(doc="The display name of the tool.")

View File

@@ -200,7 +200,11 @@ class ToolManager:
async def create_or_update_tool_async(
self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False, modal_sandbox_enabled: bool = False
) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
"""Create a new tool based on the ToolCreate schema.
Uses atomic PostgreSQL ON CONFLICT DO UPDATE to prevent race conditions
during concurrent upserts.
"""
from letta.otel.tracing import tracer
if pydantic_tool.tool_type == ToolType.CUSTOM and not pydantic_tool.json_schema:
@@ -228,7 +232,11 @@ class ToolManager:
source_code=pydantic_tool.source_code,
)
# check if the tool name already exists
# Use atomic PostgreSQL upsert if available
if settings.letta_pg_uri_no_default:
return await self._atomic_upsert_tool_postgresql(pydantic_tool, actor, modal_sandbox_enabled)
# Fallback for SQLite: use non-atomic check-then-act pattern
current_tool = await self.get_tool_by_name_async(tool_name=pydantic_tool.name, actor=actor)
if current_tool:
# Put to dict and remove fields that should not be reset
@@ -261,6 +269,95 @@ class ToolManager:
return await self.create_tool_async(pydantic_tool, actor=actor, modal_sandbox_enabled=modal_sandbox_enabled)
@enforce_types
@trace_method
async def _atomic_upsert_tool_postgresql(
    self, pydantic_tool: PydanticTool, actor: PydanticUser, modal_sandbox_enabled: bool = False
) -> PydanticTool:
    """Atomically upsert a single tool using PostgreSQL's ON CONFLICT DO UPDATE.

    This prevents race conditions when multiple concurrent requests try to
    create/update the same tool by name.

    Args:
        pydantic_tool: The tool payload to insert or update.
        actor: The user performing the action (provides org scoping and audit ids).
        modal_sandbox_enabled: When True, tags the tool metadata with sandbox="modal".

    Returns:
        The upserted tool re-read from the database as a Pydantic model.
    """
    from sqlalchemy.dialects.postgresql import insert as pg_insert

    # Auto-generate description from the JSON schema if not provided.
    if pydantic_tool.description is None and pydantic_tool.json_schema:
        pydantic_tool.description = pydantic_tool.json_schema.get("description", None)

    # Add sandbox:modal to metadata if flag is enabled.
    if modal_sandbox_enabled:
        if pydantic_tool.metadata_ is None:
            pydantic_tool.metadata_ = {}
        pydantic_tool.metadata_["sandbox"] = "modal"

    # Add tool hash to metadata for Modal deployment tracking.
    tool_hash = compute_tool_hash(pydantic_tool)
    if pydantic_tool.metadata_ is None:
        pydantic_tool.metadata_ = {}
    pydantic_tool.metadata_["tool_hash"] = tool_hash

    async with db_registry.async_session() as session:
        table = ToolModel.__table__
        valid_columns = {col.name for col in table.columns}
        tool_dict = pydantic_tool.model_dump(to_orm=True)
        tool_dict["_created_by_id"] = actor.id
        tool_dict["_last_updated_by_id"] = actor.id
        tool_dict["organization_id"] = actor.organization_id

        # Filter to only include columns that exist in the table.
        # Also exclude None values to let database defaults apply.
        insert_data = {k: v for k, v in tool_dict.items() if k in valid_columns and v is not None}

        # Build the INSERT ... ON CONFLICT DO UPDATE statement.
        stmt = pg_insert(table).values(**insert_data)

        # On conflict, update all columns except id, created_at, and _created_by_id.
        excluded = stmt.excluded
        update_dict = {}
        for col in table.columns:
            if col.name in ("id", "created_at", "_created_by_id"):
                continue
            if col.name == "updated_at":
                update_dict[col.name] = func.now()
            elif col.name == "tags" and not insert_data.get("tags"):
                # BUG FIX: the original read insert_data["tags"] directly, which
                # raised KeyError when tags was None (None values are filtered
                # out of insert_data above). .get() treats missing/None/empty
                # uniformly: skip the column so existing tags are not clobbered.
                # NOTE: deliberate trade-off — tags cannot be cleared via this
                # upsert path, only replaced with a non-empty list.
                continue
            else:
                update_dict[col.name] = excluded[col.name]

        upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict).returning(table.c.id)
        result = await session.execute(upsert_stmt)
        tool_id = result.scalar_one()
        await session.commit()

        # Fetch the upserted tool (fresh read so DB defaults/triggers are reflected).
        tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
        upserted_tool = tool.to_pydantic()

    # Deploy Modal app if needed (both Modal credentials configured AND tool metadata must indicate Modal).
    # TODO: don't have such duplicated code
    tool_requests_modal = upserted_tool.metadata_ and upserted_tool.metadata_.get("sandbox") == "modal"
    modal_configured = tool_settings.modal_sandbox_enabled
    if upserted_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
        await self.create_or_update_modal_app(upserted_tool, actor)

    # Embed tool in Turbopuffer if enabled (fire-and-forget; does not block the response).
    from letta.helpers.tpuf_client import should_use_tpuf_for_tools

    if should_use_tpuf_for_tools():
        fire_and_forget(
            self._embed_tool_background(upserted_tool, actor),
            task_name=f"embed_tool_{upserted_tool.id}",
        )

    return upserted_tool
@enforce_types
async def create_mcp_server(
self, server_config: Union[StdioServerConfig, SSEServerConfig], actor: PydanticUser
@@ -336,24 +433,25 @@ class ToolManager:
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
created_tool = tool.to_pydantic()
# Deploy Modal app for the new tool
# Both Modal credentials configured AND tool metadata must indicate Modal
tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
modal_configured = tool_settings.modal_sandbox_enabled
# TODO: don't have such duplicated code
# Deploy Modal app for the new tool
# Both Modal credentials configured AND tool metadata must indicate Modal
tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
modal_configured = tool_settings.modal_sandbox_enabled
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
await self.create_or_update_modal_app(created_tool, actor)
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
await self.create_or_update_modal_app(created_tool, actor)
# Embed tool in Turbopuffer if enabled
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
# Embed tool in Turbopuffer if enabled
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
if should_use_tpuf_for_tools():
fire_and_forget(
self._embed_tool_background(created_tool, actor),
task_name=f"embed_tool_{created_tool.id}",
)
if should_use_tpuf_for_tools():
fire_and_forget(
self._embed_tool_background(created_tool, actor),
task_name=f"embed_tool_{created_tool.id}",
)
return created_tool
return created_tool
@enforce_types
@trace_method
@@ -470,6 +568,37 @@ class ToolManager:
count = result.scalar()
return count > 0
@enforce_types
async def _check_tool_name_conflict_with_lock_async(self, session, tool_name: str, exclude_tool_id: str, actor: PydanticUser) -> bool:
    """Return True if another tool in the actor's org already uses ``tool_name``.

    Acquires a row lock (SELECT ... FOR UPDATE) on any matching row so that a
    concurrent transaction attempting the same rename must wait, preventing two
    updates from racing to the same name.

    Args:
        session: The database session (must be part of an active transaction)
        tool_name: The name to check for conflicts
        exclude_tool_id: The ID of the current tool being updated (to exclude from check)
        actor: The user performing the action

    Returns:
        True if a conflicting tool exists, False otherwise
    """
    # Build the locked lookup in one expression; nowait=False means we block
    # until any competing transaction releases its lock on the row.
    locked_lookup = (
        select(ToolModel.id)
        .where(
            ToolModel.name == tool_name,
            ToolModel.organization_id == actor.organization_id,
            ToolModel.id != exclude_tool_id,
        )
        .with_for_update(nowait=False)
    )
    row = await session.execute(locked_lookup)
    return row.scalar() is not None
@enforce_types
@trace_method
async def list_tools_async(
@@ -786,11 +915,8 @@ class ToolManager:
# f"JSON schema name '{new_name}' conflicts with current tool name '{current_tool.name}'. Update the name field explicitly if you want to rename the tool."
# )
# If name changes, enforce uniqueness
if new_name != current_tool.name:
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
if name_exists:
raise LettaToolNameConflictError(tool_name=new_name)
# Track if we need to check name uniqueness (check is done inside session with lock)
needs_name_conflict_check = new_name != current_tool.name
# NOTE: EXTREMELY HACKY, we need to stop making assumptions about the source_code
if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""):
@@ -849,6 +975,18 @@ class ToolManager:
# Now perform the update within the session
async with db_registry.async_session() as session:
# Check name uniqueness with lock INSIDE the session to prevent race conditions
# This uses SELECT FOR UPDATE to ensure no other transaction can rename to this name
if needs_name_conflict_check:
name_conflict = await self._check_tool_name_conflict_with_lock_async(
session=session,
tool_name=new_name,
exclude_tool_id=tool_id,
actor=actor,
)
if name_conflict:
raise LettaToolNameConflictError(tool_name=new_name)
# Fetch the tool by ID
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)