diff --git a/src/agent/message.ts b/src/agent/message.ts index 636d3e3..f964430 100644 --- a/src/agent/message.ts +++ b/src/agent/message.ts @@ -8,7 +8,10 @@ import type { ApprovalCreate, LettaStreamingResponse, } from "@letta-ai/letta-client/resources/agents/messages"; -import { getClientToolsFromRegistry } from "../tools/manager"; +import { + getClientToolsFromRegistry, + waitForToolsetReady, +} from "../tools/manager"; import { isTimingsEnabled } from "../utils/timing"; import { getClient } from "./client"; @@ -42,6 +45,10 @@ export async function sendMessageStream( const client = await getClient(); + // Wait for any in-progress toolset switch to complete before reading tools + // This prevents sending messages with stale tools during a switch + await waitForToolsetReady(); + let stream: Stream; if (process.env.DEBUG) { diff --git a/src/tools/manager.ts b/src/tools/manager.ts index 3764712..0c34986 100644 --- a/src/tools/manager.ts +++ b/src/tools/manager.ts @@ -229,24 +229,96 @@ export type ToolExecutionResult = { type ToolRegistry = Map; -// Use globalThis to ensure singleton across bundle -// This prevents Bun's bundler from creating duplicate instances of the registry +// Use globalThis to ensure singleton across bundle duplicates +// This prevents Bun's bundler from creating duplicate instances const REGISTRY_KEY = Symbol.for("@letta/toolRegistry"); +const SWITCH_LOCK_KEY = Symbol.for("@letta/toolSwitchLock"); -type GlobalWithRegistry = typeof globalThis & { - [key: symbol]: ToolRegistry; +interface SwitchLockState { + promise: Promise | null; + resolve: (() => void) | null; + refCount: number; // Ref-counted to handle overlapping switches +} + +type GlobalWithToolState = typeof globalThis & { + [REGISTRY_KEY]?: ToolRegistry; + [SWITCH_LOCK_KEY]?: SwitchLockState; }; function getRegistry(): ToolRegistry { - const global = globalThis as GlobalWithRegistry; + const global = globalThis as GlobalWithToolState; if (!global[REGISTRY_KEY]) { global[REGISTRY_KEY] = new Map(); } return global[REGISTRY_KEY]; } +function getSwitchLock(): SwitchLockState { + const global = globalThis as GlobalWithToolState; + if (!global[SWITCH_LOCK_KEY]) { + global[SWITCH_LOCK_KEY] = { promise: null, resolve: null, refCount: 0 }; + } + return global[SWITCH_LOCK_KEY]; +} + const toolRegistry = getRegistry(); +/** + * Acquires the toolset switch lock. Call before starting async tool loading. + * Ref-counted: multiple overlapping switches will keep the lock held until all complete. + * Any calls to waitForToolsetReady() will block until all switches finish. + */ +function acquireSwitchLock(): void { + const lock = getSwitchLock(); + lock.refCount++; + + // Only create a new promise if this is the first acquirer + if (lock.refCount === 1) { + lock.promise = new Promise((resolve) => { + lock.resolve = resolve; + }); + } +} + +/** + * Releases the toolset switch lock. Call after atomic registry swap completes. + * Only actually releases when all acquirers have released (ref-count drops to 0). + */ +function releaseSwitchLock(): void { + const lock = getSwitchLock(); + + if (lock.refCount > 0) { + lock.refCount--; + } + + // Only resolve when all switches are done + if (lock.refCount === 0 && lock.resolve) { + lock.resolve(); + lock.promise = null; + lock.resolve = null; + } +} + +/** + * Waits for any in-progress toolset switch to complete. + * Call this before reading from the registry to ensure you get the final toolset. + * Returns immediately if no switch is in progress. + */ +export async function waitForToolsetReady(): Promise { + const lock = getSwitchLock(); + if (lock.promise) { + await lock.promise; + } +} + +/** + * Checks if a toolset switch is currently in progress. + * Useful for synchronous checks where awaiting isn't possible. + */ +export function isToolsetSwitchInProgress(): boolean { + return getSwitchLock().refCount > 0; +} + /** * Resolve a server/visible tool name to an internal tool name * based on the currently loaded toolset. @@ -376,46 +448,79 @@ export async function analyzeToolApproval( return analyzeApprovalContext(toolName, toolArgs, workingDirectory); } +/** + * Atomically replaces the tool registry contents. + * This ensures no intermediate state where registry is empty or partial. + * + * @param newTools - Map of tools to replace the registry with + */ +function replaceRegistry(newTools: ToolRegistry): void { + // Single sync block - no awaits, no yields, no interleaving possible + toolRegistry.clear(); + for (const [key, value] of newTools) { + toolRegistry.set(key, value); + } +} + /** * Loads specific tools by name into the registry. * Used when resuming an agent to load only the tools attached to that agent. * + * Acquires the toolset switch lock during loading to prevent message sends from + * reading stale tools. Callers should use waitForToolsetReady() before sending messages. + * * @param toolNames - Array of specific tool names to load */ export async function loadSpecificTools(toolNames: string[]): Promise { - for (const name of toolNames) { - // Skip if tool filter is active and this tool is not enabled + // Acquire lock to signal that a switch is in progress + acquireSwitchLock(); + + try { + // Import filter once, outside the loop (avoids repeated async yields) const { toolFilter } = await import("./filter"); - if (!toolFilter.isEnabled(name)) { - continue; + + // Build new registry in a temporary map (all async work happens here) + const newRegistry: ToolRegistry = new Map(); + + for (const name of toolNames) { + // Skip if tool filter is active and this tool is not enabled + if (!toolFilter.isEnabled(name)) { + continue; + } + + // Map server-facing name to our internal tool name + const internalName = getInternalToolName(name); + + const definition = TOOL_DEFINITIONS[internalName as ToolName]; + if (!definition) { + console.warn( + `Tool ${name} (internal: ${internalName}) not found in definitions, skipping`, + ); + continue; + } + + if (!definition.impl) { + throw new Error(`Tool implementation not found for ${internalName}`); + } + + const toolSchema: ToolSchema = { + name: internalName, + description: definition.description, + input_schema: definition.schema, + }; + + // Add to temporary registry + newRegistry.set(internalName, { + schema: toolSchema, + fn: definition.impl, + }); } - // Map server-facing name to our internal tool name - const internalName = getInternalToolName(name); - - const definition = TOOL_DEFINITIONS[internalName as ToolName]; - if (!definition) { - console.warn( - `Tool ${name} (internal: ${internalName}) not found in definitions, skipping`, - ); - continue; - } - - if (!definition.impl) { - throw new Error(`Tool implementation not found for ${internalName}`); - } - - const toolSchema: ToolSchema = { - name: internalName, - description: definition.description, - input_schema: definition.schema, - }; - - // Register under the internal name so later lookups using mapping succeed - toolRegistry.set(internalName, { - schema: toolSchema, - fn: definition.impl, - }); + // Atomic swap - no yields between clear and populate + replaceRegistry(newRegistry); + } finally { + // Always release the lock, even if an error occurred + releaseSwitchLock(); } } @@ -424,95 +529,112 @@ export async function loadSpecificTools(toolNames: string[]): Promise { * This should be called on program startup. * Will error if any expected tool files are missing. * + * Acquires the toolset switch lock during loading to prevent message sends from + * reading stale tools. Callers should use waitForToolsetReady() before sending messages. + * * @returns Promise that resolves when all tools are loaded */ export async function loadTools(modelIdentifier?: string): Promise { - const { toolFilter } = await import("./filter"); + // Acquire lock to signal that a switch is in progress + acquireSwitchLock(); - // Get all subagents (built-in + custom) to inject into Task description - const allSubagentConfigs = await getAllSubagentConfigs(); - const discoveredSubagents = Object.entries(allSubagentConfigs).map( - ([name, config]) => ({ - name, - description: config.description, - recommendedModel: config.recommendedModel, - }), - ); - const filterActive = toolFilter.isActive(); + try { + const { toolFilter } = await import("./filter"); - let baseToolNames: ToolName[]; - if (!filterActive && modelIdentifier && isGeminiModel(modelIdentifier)) { - baseToolNames = GEMINI_PASCAL_TOOLS; - } else if ( - !filterActive && - modelIdentifier && - isOpenAIModel(modelIdentifier) - ) { - baseToolNames = OPENAI_PASCAL_TOOLS; - } else if (!filterActive) { - baseToolNames = ANTHROPIC_DEFAULT_TOOLS; - } else { - // When user explicitly sets --tools, respect that and allow any tool name - baseToolNames = TOOL_NAMES; - } + // Get all subagents (built-in + custom) to inject into Task description + const allSubagentConfigs = await getAllSubagentConfigs(); + const discoveredSubagents = Object.entries(allSubagentConfigs).map( + ([name, config]) => ({ + name, + description: config.description, + recommendedModel: config.recommendedModel, + }), + ); + const filterActive = toolFilter.isActive(); - for (const name of baseToolNames) { - if (!toolFilter.isEnabled(name)) { - continue; + let baseToolNames: ToolName[]; + if (!filterActive && modelIdentifier && isGeminiModel(modelIdentifier)) { + baseToolNames = GEMINI_PASCAL_TOOLS; + } else if ( + !filterActive && + modelIdentifier && + isOpenAIModel(modelIdentifier) + ) { + baseToolNames = OPENAI_PASCAL_TOOLS; + } else if (!filterActive) { + baseToolNames = ANTHROPIC_DEFAULT_TOOLS; + } else { + // When user explicitly sets --tools, respect that and allow any tool name + baseToolNames = TOOL_NAMES; } - try { - const definition = TOOL_DEFINITIONS[name]; - if (!definition) { - throw new Error(`Missing tool definition for ${name}`); + // Build new registry in a temporary map (all async work happens above) + const newRegistry: ToolRegistry = new Map(); + + for (const name of baseToolNames) { + if (!toolFilter.isEnabled(name)) { + continue; } - if (!definition.impl) { - throw new Error(`Tool implementation not found for ${name}`); - } + try { + const definition = TOOL_DEFINITIONS[name]; + if (!definition) { + throw new Error(`Missing tool definition for ${name}`); + } - // For Task tool, inject discovered subagent descriptions - let description = definition.description; - if (name === "Task" && discoveredSubagents.length > 0) { - description = injectSubagentsIntoTaskDescription( + if (!definition.impl) { + throw new Error(`Tool implementation not found for ${name}`); + } + + // For Task tool, inject discovered subagent descriptions + let description = definition.description; + if (name === "Task" && discoveredSubagents.length > 0) { + description = injectSubagentsIntoTaskDescription( + description, + discoveredSubagents, + ); + } + + const toolSchema: ToolSchema = { + name, description, - discoveredSubagents, + input_schema: definition.schema, + }; + + newRegistry.set(name, { + schema: toolSchema, + fn: definition.impl, + }); + } catch (error) { + const message = + error instanceof Error ? error.message : JSON.stringify(error); + throw new Error( + `Required tool "${name}" could not be loaded from bundled assets. ${message}`, ); } - - const toolSchema: ToolSchema = { - name, - description, - input_schema: definition.schema, - }; - - toolRegistry.set(name, { - schema: toolSchema, - fn: definition.impl, - }); - } catch (error) { - const message = - error instanceof Error ? error.message : JSON.stringify(error); - throw new Error( - `Required tool "${name}" could not be loaded from bundled assets. ${message}`, - ); } - } - // If LSP is enabled, swap Read with LSP-enhanced version - if (process.env.LETTA_ENABLE_LSP && toolRegistry.has("Read")) { - const lspDefinition = TOOL_DEFINITIONS.ReadLSP; - if (lspDefinition) { - // Replace Read with ReadLSP (but keep the name "Read" for the agent) - toolRegistry.set("Read", { - schema: { - name: "Read", // Keep the tool name as "Read" for the agent - description: lspDefinition.description, - input_schema: lspDefinition.schema, - }, - fn: lspDefinition.impl, - }); + // If LSP is enabled, swap Read with LSP-enhanced version + if (process.env.LETTA_ENABLE_LSP && newRegistry.has("Read")) { + const lspDefinition = TOOL_DEFINITIONS.ReadLSP; + if (lspDefinition) { + // Replace Read with ReadLSP (but keep the name "Read" for the agent) + newRegistry.set("Read", { + schema: { + name: "Read", // Keep the tool name as "Read" for the agent + description: lspDefinition.description, + input_schema: lspDefinition.schema, + }, + fn: lspDefinition.impl, + }); + } } + + // Atomic swap - no yields between clear and populate + replaceRegistry(newRegistry); + } finally { + // Always release the lock, even if an error occurred + releaseSwitchLock(); } } @@ -936,3 +1058,17 @@ export function getToolSchema(name: string): ToolSchema | undefined { export function clearTools(): void { toolRegistry.clear(); } + +/** + * Clears the tool registry with lock protection. + * Acquires the switch lock, clears the registry, then releases the lock. + * This ensures sendMessageStream() waits for the clear to complete. + */ +export function clearToolsWithLock(): void { + acquireSwitchLock(); + try { + toolRegistry.clear(); + } finally { + releaseSwitchLock(); + } +} diff --git a/src/tools/toolset.ts b/src/tools/toolset.ts index 6b47d91..66c335c 100644 --- a/src/tools/toolset.ts +++ b/src/tools/toolset.ts @@ -2,7 +2,7 @@ import { getClient } from "../agent/client"; import { resolveModel } from "../agent/model"; import { toolFilter } from "./filter"; import { - clearTools, + clearToolsWithLock, GEMINI_PASCAL_TOOLS, getToolNames, isOpenAIModel, @@ -121,14 +121,14 @@ export async function forceToolsetSwitch( toolsetName: ToolsetName, agentId: string, ): Promise { - // Clear currently loaded tools - clearTools(); - // Load the appropriate toolset - // Map toolset name to a model identifier for loading + // Note: loadTools/loadSpecificTools acquire a switch lock that causes + // sendMessageStream to wait, preventing messages from being sent with + // stale or partial tools during the switch. let modelForLoading: string; if (toolsetName === "none") { - // Just clear tools, no loading needed + // Clear tools with lock protection so sendMessageStream() waits + clearToolsWithLock(); return; } else if (toolsetName === "codex") { await loadSpecificTools([...CODEX_TOOLS]); @@ -174,8 +174,9 @@ export async function switchToolsetForModel( // Resolve model ID to handle when possible so provider checks stay consistent const resolvedModel = resolveModel(modelIdentifier) ?? modelIdentifier; - // Clear currently loaded tools and load the appropriate set for the target model - clearTools(); + // Load the appropriate set for the target model + // Note: loadTools acquires a switch lock that causes sendMessageStream to wait, + // preventing messages from being sent with stale or partial tools during the switch. await loadTools(resolvedModel); // If no tools were loaded (e.g., unexpected handle or edge-case filter),