feat: add index and concurrency control for tools (#6547)
This commit is contained in:
committed by
Caren Thomas
parent
09c027692f
commit
a2d3011d84
@@ -0,0 +1,37 @@
|
||||
"""add tool indexes for organization_id
|
||||
|
||||
Revision ID: af842aa6f743
|
||||
Revises: 175dd10fb916
|
||||
Create Date: 2025-12-07 15:30:43.407495
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "af842aa6f743"
|
||||
down_revision: Union[str, None] = "175dd10fb916"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("idx_messages_step_id"), table_name="messages")
|
||||
op.drop_index(op.f("ix_step_metrics_run_id"), table_name="step_metrics")
|
||||
op.create_index("ix_tools_organization_id", "tools", ["organization_id"], unique=False)
|
||||
op.create_index("ix_tools_organization_id_name", "tools", ["organization_id", "name"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_tools_organization_id_name", table_name="tools")
|
||||
op.drop_index("ix_tools_organization_id", table_name="tools")
|
||||
op.create_index(op.f("ix_step_metrics_run_id"), "step_metrics", ["run_id"], unique=False)
|
||||
op.create_index(op.f("idx_messages_step_id"), "messages", ["step_id"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
@@ -30,6 +30,8 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
__table_args__ = (
|
||||
UniqueConstraint("name", "organization_id", name="uix_name_organization"),
|
||||
Index("ix_tools_created_at_name", "created_at", "name"),
|
||||
Index("ix_tools_organization_id", "organization_id"),
|
||||
Index("ix_tools_organization_id_name", "organization_id", "name"),
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
|
||||
|
||||
@@ -200,7 +200,11 @@ class ToolManager:
|
||||
async def create_or_update_tool_async(
|
||||
self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False, modal_sandbox_enabled: bool = False
|
||||
) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
"""Create a new tool based on the ToolCreate schema.
|
||||
|
||||
Uses atomic PostgreSQL ON CONFLICT DO UPDATE to prevent race conditions
|
||||
during concurrent upserts.
|
||||
"""
|
||||
from letta.otel.tracing import tracer
|
||||
|
||||
if pydantic_tool.tool_type == ToolType.CUSTOM and not pydantic_tool.json_schema:
|
||||
@@ -228,7 +232,11 @@ class ToolManager:
|
||||
source_code=pydantic_tool.source_code,
|
||||
)
|
||||
|
||||
# check if the tool name already exists
|
||||
# Use atomic PostgreSQL upsert if available
|
||||
if settings.letta_pg_uri_no_default:
|
||||
return await self._atomic_upsert_tool_postgresql(pydantic_tool, actor, modal_sandbox_enabled)
|
||||
|
||||
# Fallback for SQLite: use non-atomic check-then-act pattern
|
||||
current_tool = await self.get_tool_by_name_async(tool_name=pydantic_tool.name, actor=actor)
|
||||
if current_tool:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
@@ -261,6 +269,95 @@ class ToolManager:
|
||||
|
||||
return await self.create_tool_async(pydantic_tool, actor=actor, modal_sandbox_enabled=modal_sandbox_enabled)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def _atomic_upsert_tool_postgresql(
|
||||
self, pydantic_tool: PydanticTool, actor: PydanticUser, modal_sandbox_enabled: bool = False
|
||||
) -> PydanticTool:
|
||||
"""Atomically upsert a single tool using PostgreSQL's ON CONFLICT DO UPDATE.
|
||||
|
||||
This prevents race conditions when multiple concurrent requests try to
|
||||
create/update the same tool by name.
|
||||
"""
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None and pydantic_tool.json_schema:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
|
||||
# Add sandbox:modal to metadata if flag is enabled
|
||||
if modal_sandbox_enabled:
|
||||
if pydantic_tool.metadata_ is None:
|
||||
pydantic_tool.metadata_ = {}
|
||||
pydantic_tool.metadata_["sandbox"] = "modal"
|
||||
|
||||
# Add tool hash to metadata for Modal deployment tracking
|
||||
tool_hash = compute_tool_hash(pydantic_tool)
|
||||
if pydantic_tool.metadata_ is None:
|
||||
pydantic_tool.metadata_ = {}
|
||||
pydantic_tool.metadata_["tool_hash"] = tool_hash
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
table = ToolModel.__table__
|
||||
valid_columns = {col.name for col in table.columns}
|
||||
|
||||
tool_dict = pydantic_tool.model_dump(to_orm=True)
|
||||
tool_dict["_created_by_id"] = actor.id
|
||||
tool_dict["_last_updated_by_id"] = actor.id
|
||||
tool_dict["organization_id"] = actor.organization_id
|
||||
|
||||
# Filter to only include columns that exist in the table
|
||||
# Also exclude None values to let database defaults apply
|
||||
insert_data = {k: v for k, v in tool_dict.items() if k in valid_columns and v is not None}
|
||||
|
||||
# Build the INSERT ... ON CONFLICT DO UPDATE statement
|
||||
stmt = pg_insert(table).values(**insert_data)
|
||||
|
||||
# On conflict, update all columns except id, created_at, and _created_by_id
|
||||
excluded = stmt.excluded
|
||||
update_dict = {}
|
||||
for col in table.columns:
|
||||
if col.name not in ("id", "created_at", "_created_by_id"):
|
||||
if col.name == "updated_at":
|
||||
update_dict[col.name] = func.now()
|
||||
elif col.name == "tags" and (insert_data["tags"] is None or len(insert_data["tags"]) == 0):
|
||||
# TODO: intentional bug to avoid overriding with empty tags on every upsert
|
||||
# means you cannot clear tags, only override them
|
||||
if insert_data["tags"] is None or len(insert_data["tags"]) == 0:
|
||||
continue
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict).returning(table.c.id)
|
||||
|
||||
result = await session.execute(upsert_stmt)
|
||||
tool_id = result.scalar_one()
|
||||
await session.commit()
|
||||
|
||||
# Fetch the upserted tool
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
upserted_tool = tool.to_pydantic()
|
||||
|
||||
# Deploy Modal app if needed (both Modal credentials configured AND tool metadata must indicate Modal)
|
||||
# TODO: dont have such duplicated code
|
||||
tool_requests_modal = upserted_tool.metadata_ and upserted_tool.metadata_.get("sandbox") == "modal"
|
||||
modal_configured = tool_settings.modal_sandbox_enabled
|
||||
|
||||
if upserted_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
|
||||
await self.create_or_update_modal_app(upserted_tool, actor)
|
||||
|
||||
# Embed tool in Turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
|
||||
|
||||
if should_use_tpuf_for_tools():
|
||||
fire_and_forget(
|
||||
self._embed_tool_background(upserted_tool, actor),
|
||||
task_name=f"embed_tool_{upserted_tool.id}",
|
||||
)
|
||||
|
||||
return upserted_tool
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server(
|
||||
self, server_config: Union[StdioServerConfig, SSEServerConfig], actor: PydanticUser
|
||||
@@ -336,24 +433,25 @@ class ToolManager:
|
||||
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
|
||||
created_tool = tool.to_pydantic()
|
||||
|
||||
# Deploy Modal app for the new tool
|
||||
# Both Modal credentials configured AND tool metadata must indicate Modal
|
||||
tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
|
||||
modal_configured = tool_settings.modal_sandbox_enabled
|
||||
# TODO: dont have such duplicated code
|
||||
# Deploy Modal app for the new tool
|
||||
# Both Modal credentials configured AND tool metadata must indicate Modal
|
||||
tool_requests_modal = created_tool.metadata_ and created_tool.metadata_.get("sandbox") == "modal"
|
||||
modal_configured = tool_settings.modal_sandbox_enabled
|
||||
|
||||
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
|
||||
await self.create_or_update_modal_app(created_tool, actor)
|
||||
if created_tool.tool_type == ToolType.CUSTOM and tool_requests_modal and modal_configured:
|
||||
await self.create_or_update_modal_app(created_tool, actor)
|
||||
|
||||
# Embed tool in Turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
|
||||
# Embed tool in Turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_tools
|
||||
|
||||
if should_use_tpuf_for_tools():
|
||||
fire_and_forget(
|
||||
self._embed_tool_background(created_tool, actor),
|
||||
task_name=f"embed_tool_{created_tool.id}",
|
||||
)
|
||||
if should_use_tpuf_for_tools():
|
||||
fire_and_forget(
|
||||
self._embed_tool_background(created_tool, actor),
|
||||
task_name=f"embed_tool_{created_tool.id}",
|
||||
)
|
||||
|
||||
return created_tool
|
||||
return created_tool
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -470,6 +568,37 @@ class ToolManager:
|
||||
count = result.scalar()
|
||||
return count > 0
|
||||
|
||||
@enforce_types
|
||||
async def _check_tool_name_conflict_with_lock_async(self, session, tool_name: str, exclude_tool_id: str, actor: PydanticUser) -> bool:
|
||||
"""Check if a tool with the given name exists (excluding the current tool), with row locking.
|
||||
|
||||
Uses SELECT FOR UPDATE to prevent race conditions when two concurrent updates
|
||||
try to rename tools to the same name.
|
||||
|
||||
Args:
|
||||
session: The database session (must be part of an active transaction)
|
||||
tool_name: The name to check for conflicts
|
||||
exclude_tool_id: The ID of the current tool being updated (to exclude from check)
|
||||
actor: The user performing the action
|
||||
|
||||
Returns:
|
||||
True if a conflicting tool exists, False otherwise
|
||||
"""
|
||||
# Use SELECT FOR UPDATE to lock any existing row with this name
|
||||
# This prevents another concurrent transaction from also checking and then updating
|
||||
query = (
|
||||
select(ToolModel.id)
|
||||
.where(
|
||||
ToolModel.name == tool_name,
|
||||
ToolModel.organization_id == actor.organization_id,
|
||||
ToolModel.id != exclude_tool_id,
|
||||
)
|
||||
.with_for_update(nowait=False) # Wait for lock if another transaction holds it
|
||||
)
|
||||
result = await session.execute(query)
|
||||
existing_tool = result.scalar()
|
||||
return existing_tool is not None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tools_async(
|
||||
@@ -786,11 +915,8 @@ class ToolManager:
|
||||
# f"JSON schema name '{new_name}' conflicts with current tool name '{current_tool.name}'. Update the name field explicitly if you want to rename the tool."
|
||||
# )
|
||||
|
||||
# If name changes, enforce uniqueness
|
||||
if new_name != current_tool.name:
|
||||
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
||||
if name_exists:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
# Track if we need to check name uniqueness (check is done inside session with lock)
|
||||
needs_name_conflict_check = new_name != current_tool.name
|
||||
|
||||
# NOTE: EXTREMELEY HACKY, we need to stop making assumptions about the source_code
|
||||
if "source_code" in update_data and f"def {new_name}" not in update_data.get("source_code", ""):
|
||||
@@ -849,6 +975,18 @@ class ToolManager:
|
||||
|
||||
# Now perform the update within the session
|
||||
async with db_registry.async_session() as session:
|
||||
# Check name uniqueness with lock INSIDE the session to prevent race conditions
|
||||
# This uses SELECT FOR UPDATE to ensure no other transaction can rename to this name
|
||||
if needs_name_conflict_check:
|
||||
name_conflict = await self._check_tool_name_conflict_with_lock_async(
|
||||
session=session,
|
||||
tool_name=new_name,
|
||||
exclude_tool_id=tool_id,
|
||||
actor=actor,
|
||||
)
|
||||
if name_conflict:
|
||||
raise LettaToolNameConflictError(tool_name=new_name)
|
||||
|
||||
# Fetch the tool by ID
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user