fix: make toolset switching atomic to prevent tool desync race (#648)

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
Charles Packer
2026-01-22 17:33:14 -08:00
committed by GitHub
parent ebe3a344f1
commit e32b10f931
3 changed files with 260 additions and 116 deletions

View File

@@ -8,7 +8,10 @@ import type {
ApprovalCreate,
LettaStreamingResponse,
} from "@letta-ai/letta-client/resources/agents/messages";
import { getClientToolsFromRegistry } from "../tools/manager";
import {
getClientToolsFromRegistry,
waitForToolsetReady,
} from "../tools/manager";
import { isTimingsEnabled } from "../utils/timing";
import { getClient } from "./client";
@@ -42,6 +45,10 @@ export async function sendMessageStream(
const client = await getClient();
// Wait for any in-progress toolset switch to complete before reading tools
// This prevents sending messages with stale tools during a switch
await waitForToolsetReady();
let stream: Stream<LettaStreamingResponse>;
if (process.env.DEBUG) {

View File

@@ -229,24 +229,96 @@ export type ToolExecutionResult = {
type ToolRegistry = Map<string, ToolDefinition>;
// Use globalThis to ensure singleton across bundle
// This prevents Bun's bundler from creating duplicate instances of the registry
// Use globalThis to ensure singleton across bundle duplicates
// This prevents Bun's bundler from creating duplicate instances
const REGISTRY_KEY = Symbol.for("@letta/toolRegistry");
const SWITCH_LOCK_KEY = Symbol.for("@letta/toolSwitchLock");
type GlobalWithRegistry = typeof globalThis & {
[key: symbol]: ToolRegistry;
interface SwitchLockState {
promise: Promise<void> | null;
resolve: (() => void) | null;
refCount: number; // Ref-counted to handle overlapping switches
}
type GlobalWithToolState = typeof globalThis & {
[REGISTRY_KEY]?: ToolRegistry;
[SWITCH_LOCK_KEY]?: SwitchLockState;
};
function getRegistry(): ToolRegistry {
const global = globalThis as GlobalWithRegistry;
const global = globalThis as GlobalWithToolState;
if (!global[REGISTRY_KEY]) {
global[REGISTRY_KEY] = new Map();
}
return global[REGISTRY_KEY];
}
function getSwitchLock(): SwitchLockState {
const global = globalThis as GlobalWithToolState;
if (!global[SWITCH_LOCK_KEY]) {
global[SWITCH_LOCK_KEY] = { promise: null, resolve: null, refCount: 0 };
}
return global[SWITCH_LOCK_KEY];
}
const toolRegistry = getRegistry();
/**
* Acquires the toolset switch lock. Call before starting async tool loading.
* Ref-counted: multiple overlapping switches will keep the lock held until all complete.
* Any calls to waitForToolsetReady() will block until all switches finish.
*/
function acquireSwitchLock(): void {
const lock = getSwitchLock();
lock.refCount++;
// Only create a new promise if this is the first acquirer
if (lock.refCount === 1) {
lock.promise = new Promise((resolve) => {
lock.resolve = resolve;
});
}
}
/**
* Releases the toolset switch lock. Call after atomic registry swap completes.
* Only actually releases when all acquirers have released (ref-count drops to 0).
*/
function releaseSwitchLock(): void {
const lock = getSwitchLock();
if (lock.refCount > 0) {
lock.refCount--;
}
// Only resolve when all switches are done
if (lock.refCount === 0 && lock.resolve) {
lock.resolve();
lock.promise = null;
lock.resolve = null;
}
}
/**
* Waits for any in-progress toolset switch to complete.
* Call this before reading from the registry to ensure you get the final toolset.
* Returns immediately if no switch is in progress.
*/
export async function waitForToolsetReady(): Promise<void> {
const lock = getSwitchLock();
if (lock.promise) {
await lock.promise;
}
}
/**
* Checks if a toolset switch is currently in progress.
* Useful for synchronous checks where awaiting isn't possible.
*/
export function isToolsetSwitchInProgress(): boolean {
return getSwitchLock().refCount > 0;
}
/**
* Resolve a server/visible tool name to an internal tool name
* based on the currently loaded toolset.
@@ -376,46 +448,79 @@ export async function analyzeToolApproval(
return analyzeApprovalContext(toolName, toolArgs, workingDirectory);
}
/**
* Atomically replaces the tool registry contents.
* This ensures no intermediate state where registry is empty or partial.
*
* @param newTools - Map of tools to replace the registry with
*/
function replaceRegistry(newTools: ToolRegistry): void {
// Single sync block - no awaits, no yields, no interleaving possible
toolRegistry.clear();
for (const [key, value] of newTools) {
toolRegistry.set(key, value);
}
}
/**
* Loads specific tools by name into the registry.
* Used when resuming an agent to load only the tools attached to that agent.
*
* Acquires the toolset switch lock during loading to prevent message sends from
* reading stale tools. Callers should use waitForToolsetReady() before sending messages.
*
* @param toolNames - Array of specific tool names to load
*/
export async function loadSpecificTools(toolNames: string[]): Promise<void> {
for (const name of toolNames) {
// Skip if tool filter is active and this tool is not enabled
// Acquire lock to signal that a switch is in progress
acquireSwitchLock();
try {
// Import filter once, outside the loop (avoids repeated async yields)
const { toolFilter } = await import("./filter");
if (!toolFilter.isEnabled(name)) {
continue;
// Build new registry in a temporary map (all async work happens here)
const newRegistry: ToolRegistry = new Map();
for (const name of toolNames) {
// Skip if tool filter is active and this tool is not enabled
if (!toolFilter.isEnabled(name)) {
continue;
}
// Map server-facing name to our internal tool name
const internalName = getInternalToolName(name);
const definition = TOOL_DEFINITIONS[internalName as ToolName];
if (!definition) {
console.warn(
`Tool ${name} (internal: ${internalName}) not found in definitions, skipping`,
);
continue;
}
if (!definition.impl) {
throw new Error(`Tool implementation not found for ${internalName}`);
}
const toolSchema: ToolSchema = {
name: internalName,
description: definition.description,
input_schema: definition.schema,
};
// Add to temporary registry
newRegistry.set(internalName, {
schema: toolSchema,
fn: definition.impl,
});
}
// Map server-facing name to our internal tool name
const internalName = getInternalToolName(name);
const definition = TOOL_DEFINITIONS[internalName as ToolName];
if (!definition) {
console.warn(
`Tool ${name} (internal: ${internalName}) not found in definitions, skipping`,
);
continue;
}
if (!definition.impl) {
throw new Error(`Tool implementation not found for ${internalName}`);
}
const toolSchema: ToolSchema = {
name: internalName,
description: definition.description,
input_schema: definition.schema,
};
// Register under the internal name so later lookups using mapping succeed
toolRegistry.set(internalName, {
schema: toolSchema,
fn: definition.impl,
});
// Atomic swap - no yields between clear and populate
replaceRegistry(newRegistry);
} finally {
// Always release the lock, even if an error occurred
releaseSwitchLock();
}
}
@@ -424,95 +529,112 @@ export async function loadSpecificTools(toolNames: string[]): Promise<void> {
* This should be called on program startup.
* Will error if any expected tool files are missing.
*
* Acquires the toolset switch lock during loading to prevent message sends from
* reading stale tools. Callers should use waitForToolsetReady() before sending messages.
*
* @returns Promise that resolves when all tools are loaded
*/
export async function loadTools(modelIdentifier?: string): Promise<void> {
const { toolFilter } = await import("./filter");
// Acquire lock to signal that a switch is in progress
acquireSwitchLock();
// Get all subagents (built-in + custom) to inject into Task description
const allSubagentConfigs = await getAllSubagentConfigs();
const discoveredSubagents = Object.entries(allSubagentConfigs).map(
([name, config]) => ({
name,
description: config.description,
recommendedModel: config.recommendedModel,
}),
);
const filterActive = toolFilter.isActive();
try {
const { toolFilter } = await import("./filter");
let baseToolNames: ToolName[];
if (!filterActive && modelIdentifier && isGeminiModel(modelIdentifier)) {
baseToolNames = GEMINI_PASCAL_TOOLS;
} else if (
!filterActive &&
modelIdentifier &&
isOpenAIModel(modelIdentifier)
) {
baseToolNames = OPENAI_PASCAL_TOOLS;
} else if (!filterActive) {
baseToolNames = ANTHROPIC_DEFAULT_TOOLS;
} else {
// When user explicitly sets --tools, respect that and allow any tool name
baseToolNames = TOOL_NAMES;
}
// Get all subagents (built-in + custom) to inject into Task description
const allSubagentConfigs = await getAllSubagentConfigs();
const discoveredSubagents = Object.entries(allSubagentConfigs).map(
([name, config]) => ({
name,
description: config.description,
recommendedModel: config.recommendedModel,
}),
);
const filterActive = toolFilter.isActive();
for (const name of baseToolNames) {
if (!toolFilter.isEnabled(name)) {
continue;
let baseToolNames: ToolName[];
if (!filterActive && modelIdentifier && isGeminiModel(modelIdentifier)) {
baseToolNames = GEMINI_PASCAL_TOOLS;
} else if (
!filterActive &&
modelIdentifier &&
isOpenAIModel(modelIdentifier)
) {
baseToolNames = OPENAI_PASCAL_TOOLS;
} else if (!filterActive) {
baseToolNames = ANTHROPIC_DEFAULT_TOOLS;
} else {
// When user explicitly sets --tools, respect that and allow any tool name
baseToolNames = TOOL_NAMES;
}
try {
const definition = TOOL_DEFINITIONS[name];
if (!definition) {
throw new Error(`Missing tool definition for ${name}`);
// Build new registry in a temporary map (all async work happens above)
const newRegistry: ToolRegistry = new Map();
for (const name of baseToolNames) {
if (!toolFilter.isEnabled(name)) {
continue;
}
if (!definition.impl) {
throw new Error(`Tool implementation not found for ${name}`);
}
try {
const definition = TOOL_DEFINITIONS[name];
if (!definition) {
throw new Error(`Missing tool definition for ${name}`);
}
// For Task tool, inject discovered subagent descriptions
let description = definition.description;
if (name === "Task" && discoveredSubagents.length > 0) {
description = injectSubagentsIntoTaskDescription(
if (!definition.impl) {
throw new Error(`Tool implementation not found for ${name}`);
}
// For Task tool, inject discovered subagent descriptions
let description = definition.description;
if (name === "Task" && discoveredSubagents.length > 0) {
description = injectSubagentsIntoTaskDescription(
description,
discoveredSubagents,
);
}
const toolSchema: ToolSchema = {
name,
description,
discoveredSubagents,
input_schema: definition.schema,
};
newRegistry.set(name, {
schema: toolSchema,
fn: definition.impl,
});
} catch (error) {
const message =
error instanceof Error ? error.message : JSON.stringify(error);
throw new Error(
`Required tool "${name}" could not be loaded from bundled assets. ${message}`,
);
}
const toolSchema: ToolSchema = {
name,
description,
input_schema: definition.schema,
};
toolRegistry.set(name, {
schema: toolSchema,
fn: definition.impl,
});
} catch (error) {
const message =
error instanceof Error ? error.message : JSON.stringify(error);
throw new Error(
`Required tool "${name}" could not be loaded from bundled assets. ${message}`,
);
}
}
// If LSP is enabled, swap Read with LSP-enhanced version
if (process.env.LETTA_ENABLE_LSP && toolRegistry.has("Read")) {
const lspDefinition = TOOL_DEFINITIONS.ReadLSP;
if (lspDefinition) {
// Replace Read with ReadLSP (but keep the name "Read" for the agent)
toolRegistry.set("Read", {
schema: {
name: "Read", // Keep the tool name as "Read" for the agent
description: lspDefinition.description,
input_schema: lspDefinition.schema,
},
fn: lspDefinition.impl,
});
// If LSP is enabled, swap Read with LSP-enhanced version
if (process.env.LETTA_ENABLE_LSP && newRegistry.has("Read")) {
const lspDefinition = TOOL_DEFINITIONS.ReadLSP;
if (lspDefinition) {
// Replace Read with ReadLSP (but keep the name "Read" for the agent)
newRegistry.set("Read", {
schema: {
name: "Read", // Keep the tool name as "Read" for the agent
description: lspDefinition.description,
input_schema: lspDefinition.schema,
},
fn: lspDefinition.impl,
});
}
}
// Atomic swap - no yields between clear and populate
replaceRegistry(newRegistry);
} finally {
// Always release the lock, even if an error occurred
releaseSwitchLock();
}
}
@@ -936,3 +1058,17 @@ export function getToolSchema(name: string): ToolSchema | undefined {
export function clearTools(): void {
toolRegistry.clear();
}
/**
* Clears the tool registry with lock protection.
* Acquires the switch lock, clears the registry, then releases the lock.
* This ensures sendMessageStream() waits for the clear to complete.
*/
export function clearToolsWithLock(): void {
acquireSwitchLock();
try {
toolRegistry.clear();
} finally {
releaseSwitchLock();
}
}

View File

@@ -2,7 +2,7 @@ import { getClient } from "../agent/client";
import { resolveModel } from "../agent/model";
import { toolFilter } from "./filter";
import {
clearTools,
clearToolsWithLock,
GEMINI_PASCAL_TOOLS,
getToolNames,
isOpenAIModel,
@@ -121,14 +121,14 @@ export async function forceToolsetSwitch(
toolsetName: ToolsetName,
agentId: string,
): Promise<void> {
// Clear currently loaded tools
clearTools();
// Load the appropriate toolset
// Map toolset name to a model identifier for loading
// Note: loadTools/loadSpecificTools acquire a switch lock that causes
// sendMessageStream to wait, preventing messages from being sent with
// stale or partial tools during the switch.
let modelForLoading: string;
if (toolsetName === "none") {
// Just clear tools, no loading needed
// Clear tools with lock protection so sendMessageStream() waits
clearToolsWithLock();
return;
} else if (toolsetName === "codex") {
await loadSpecificTools([...CODEX_TOOLS]);
@@ -174,8 +174,9 @@ export async function switchToolsetForModel(
// Resolve model ID to handle when possible so provider checks stay consistent
const resolvedModel = resolveModel(modelIdentifier) ?? modelIdentifier;
// Clear currently loaded tools and load the appropriate set for the target model
clearTools();
// Load the appropriate set for the target model
// Note: loadTools acquires a switch lock that causes sendMessageStream to wait,
// preventing messages from being sent with stale or partial tools during the switch.
await loadTools(resolvedModel);
// If no tools were loaded (e.g., unexpected handle or edge-case filter),