diff --git a/src/agent/approval-execution.ts b/src/agent/approval-execution.ts index 8c154a8..8715c34 100644 --- a/src/agent/approval-execution.ts +++ b/src/agent/approval-execution.ts @@ -189,6 +189,7 @@ async function executeSingleDecision( chunk: string, isStderr?: boolean, ) => void; + toolContextId?: string; }, ): Promise { // If aborted, record an interrupted result @@ -245,6 +246,7 @@ async function executeSingleDecision( { signal: options?.abortSignal, toolCallId: decision.approval.toolCallId, + toolContextId: options?.toolContextId, onOutput: options?.onStreamingOutput ? (chunk, stream) => options.onStreamingOutput?.( @@ -357,6 +359,7 @@ export async function executeApprovalBatch( chunk: string, isStderr?: boolean, ) => void; + toolContextId?: string; }, ): Promise { // Pre-allocate results array to maintain original order @@ -452,6 +455,7 @@ export async function executeAutoAllowedTools( chunk: string, isStderr?: boolean, ) => void; + toolContextId?: string; }, ): Promise { const decisions: ApprovalDecision[] = autoAllowed.map((ac) => ({ diff --git a/src/agent/message.ts b/src/agent/message.ts index 0385588..6df142c 100644 --- a/src/agent/message.ts +++ b/src/agent/message.ts @@ -9,14 +9,26 @@ import type { LettaStreamingResponse, } from "@letta-ai/letta-client/resources/agents/messages"; import { - getClientToolsFromRegistry, + captureToolExecutionContext, waitForToolsetReady, } from "../tools/manager"; import { isTimingsEnabled } from "../utils/timing"; import { getClient } from "./client"; -// Symbol to store timing info on the stream object -export const STREAM_REQUEST_START_TIME = Symbol("streamRequestStartTime"); +const streamRequestStartTimes = new WeakMap(); +const streamToolContextIds = new WeakMap(); + +export function getStreamRequestStartTime( + stream: Stream, +): number | undefined { + return streamRequestStartTimes.get(stream as object); +} + +export function getStreamToolContextId( + stream: Stream, +): string | null { + return streamToolContextIds.get(stream as object) ?? null; +} /** * Send a message to a conversation and return a streaming response. @@ -40,14 +52,13 @@ export async function sendMessageStream( // requestOptions: { maxRetries?: number } = { maxRetries: 0 }, requestOptions: { maxRetries?: number } = {}, ): Promise> { - // Capture request start time for TTFT measurement when timings are enabled const requestStartTime = isTimingsEnabled() ? performance.now() : undefined; - const client = await getClient(); // Wait for any in-progress toolset switch to complete before reading tools // This prevents sending messages with stale tools during a switch await waitForToolsetReady(); + const { clientTools, contextId } = captureToolExecutionContext(); let stream: Stream; @@ -71,7 +82,7 @@ export async function sendMessageStream( streaming: true, stream_tokens: opts.streamTokens ?? true, background: opts.background ?? true, - client_tools: getClientToolsFromRegistry(), + client_tools: clientTools, include_compaction_messages: true, }, requestOptions, @@ -85,18 +96,17 @@ export async function sendMessageStream( streaming: true, stream_tokens: opts.streamTokens ?? true, background: opts.background ?? true, - client_tools: getClientToolsFromRegistry(), + client_tools: clientTools, include_compaction_messages: true, }, requestOptions, ); } - // Attach start time to stream for TTFT calculation in drainStream if (requestStartTime !== undefined) { - (stream as unknown as Record)[STREAM_REQUEST_START_TIME] = - requestStartTime; + streamRequestStartTimes.set(stream as object, requestStartTime); } + streamToolContextIds.set(stream as object, contextId); return stream; } diff --git a/src/cli/App.tsx b/src/cli/App.tsx index 5c82946..f8497e5 100644 --- a/src/cli/App.tsx +++ b/src/cli/App.tsx @@ -51,7 +51,7 @@ import { ensureMemoryFilesystemDirs, getMemoryFilesystemRoot, } from "../agent/memoryFilesystem"; -import { sendMessageStream } from "../agent/message"; +import { getStreamToolContextId, sendMessageStream } from "../agent/message"; import { getModelInfo, getModelShortName, @@ -95,6 +95,7 @@ import { analyzeToolApproval, checkToolPermission, executeTool, + releaseToolExecutionContext, savePermissionRule, type ToolExecutionResult, } from "../tools/manager"; @@ -1570,6 +1571,13 @@ export default function App({ const lastSentInputRef = useRef | null>( null, ); + const approvalToolContextIdRef = useRef(null); + const clearApprovalToolContext = useCallback(() => { + const contextId = approvalToolContextIdRef.current; + if (!contextId) return; + approvalToolContextIdRef.current = null; + releaseToolExecutionContext(contextId); + }, []); // Non-null only when the previous turn was explicitly interrupted by the user. // Used to gate recovery alert injection to true user-interrupt retries. const pendingInterruptRecoveryConversationIdRef = useRef(null); @@ -3173,12 +3181,14 @@ export default function App({ // throws before streaming begins, e.g., retry after LLM error when backend // already cleared the approval) let stream: Awaited>; + let turnToolContextId: string | null = null; try { stream = await sendMessageStream( conversationIdRef.current, currentInput, { agentId: agentIdRef.current }, ); + turnToolContextId = getStreamToolContextId(stream); } catch (preStreamError) { // Extract error detail using shared helper (handles nested/direct/message shapes) const errorDetail = extractConflictDetail(preStreamError); @@ -3599,6 +3609,7 @@ export default function App({ // Case 1: Turn ended normally if (stopReasonToHandle === "end_turn") { + clearApprovalToolContext(); setStreaming(false); const liveElapsedMs = (() => { const snapshot = sessionStatsRef.current.getTrajectorySnapshot(); @@ -3775,6 +3786,7 @@ export default function App({ // Case 1.5: Stream was cancelled by user if (stopReasonToHandle === "cancelled") { + clearApprovalToolContext(); setStreaming(false); closeTrajectorySegment(); syncTrajectoryElapsedBase(); @@ -3824,6 +3836,8 @@ export default function App({ // Case 2: Requires approval if (stopReasonToHandle === "requires_approval") { + clearApprovalToolContext(); + approvalToolContextIdRef.current = turnToolContextId; // Clear stale state immediately to prevent ID mismatch bugs setAutoHandledResults([]); setAutoDeniedApprovals([]); @@ -3839,6 +3853,7 @@ export default function App({ : []; if (approvalsToProcess.length === 0) { + clearApprovalToolContext(); appendError( `Unexpected empty approvals with stop reason: ${stopReason}`, ); @@ -3851,6 +3866,7 @@ export default function App({ // If in quietCancel mode (user queued messages), auto-reject all approvals // and send denials + queued messages together if (waitingForQueueCancelRef.current) { + clearApprovalToolContext(); // Create denial results for all approvals const denialResults = approvalsToProcess.map((approvalItem) => ({ type: "approval" as const, @@ -3898,6 +3914,7 @@ export default function App({ userCancelledRef.current || abortControllerRef.current?.signal.aborted ) { + clearApprovalToolContext(); setStreaming(false); closeTrajectorySegment(); syncTrajectoryElapsedBase(); @@ -4034,6 +4051,8 @@ export default function App({ { abortSignal: autoAllowedAbortController.signal, onStreamingOutput: updateStreamingOutput, + toolContextId: + approvalToolContextIdRef.current ?? undefined, }, ) : []; @@ -4744,6 +4763,7 @@ export default function App({ consumeQueuedMessages, appendTaskNotificationEvents, maybeCheckMemoryGitStatus, + clearApprovalToolContext, openTrajectorySegment, syncTrajectoryTokenBase, syncTrajectoryElapsedBase, @@ -5550,6 +5570,7 @@ export default function App({ { abortSignal: autoAllowedAbortController.signal, onStreamingOutput: updateStreamingOutput, + toolContextId: approvalToolContextIdRef.current ?? undefined, }, ); // Map to ApprovalResult format (ToolReturn) @@ -8572,6 +8593,8 @@ ${SYSTEM_REMINDER_CLOSE} { abortSignal: autoAllowedAbortController.signal, onStreamingOutput: updateStreamingOutput, + toolContextId: + approvalToolContextIdRef.current ?? undefined, }, ) : []; @@ -8816,6 +8839,8 @@ ${SYSTEM_REMINDER_CLOSE} { abortSignal: autoAllowedAbortController.signal, onStreamingOutput: updateStreamingOutput, + toolContextId: + approvalToolContextIdRef.current ?? undefined, }, ) : []; @@ -9002,6 +9027,7 @@ ${SYSTEM_REMINDER_CLOSE} if ( !streaming && hasAnythingQueued && + !queuedOverlayAction && // Prioritize queued model/toolset/system switches before dequeuing messages pendingApprovals.length === 0 && !commandRunning && !isExecutingTool && @@ -9035,7 +9061,7 @@ ${SYSTEM_REMINDER_CLOSE} // Log why dequeue was blocked (useful for debugging stuck queues) debugLog( "queue", - `Dequeue blocked: streaming=${streaming}, pendingApprovals=${pendingApprovals.length}, commandRunning=${commandRunning}, isExecutingTool=${isExecutingTool}, anySelectorOpen=${anySelectorOpen}, waitingForQueueCancel=${waitingForQueueCancelRef.current}, userCancelled=${userCancelledRef.current}, abortController=${!!abortControllerRef.current}`, + `Dequeue blocked: streaming=${streaming}, queuedOverlayAction=${!!queuedOverlayAction}, pendingApprovals=${pendingApprovals.length}, commandRunning=${commandRunning}, isExecutingTool=${isExecutingTool}, anySelectorOpen=${anySelectorOpen}, waitingForQueueCancel=${waitingForQueueCancelRef.current}, userCancelled=${userCancelledRef.current}, abortController=${!!abortControllerRef.current}`, ); } }, [ @@ -9045,6 +9071,7 @@ ${SYSTEM_REMINDER_CLOSE} commandRunning, isExecutingTool, anySelectorOpen, + queuedOverlayAction, dequeueEpoch, // Triggered when userCancelledRef is reset OR task notifications added ]); @@ -9155,6 +9182,7 @@ ${SYSTEM_REMINDER_CLOSE} { abortSignal: approvalAbortController.signal, onStreamingOutput: updateStreamingOutput, + toolContextId: approvalToolContextIdRef.current ?? undefined, }, ); } finally { @@ -9281,6 +9309,7 @@ ${SYSTEM_REMINDER_CLOSE} } } finally { // Always release the execution guard, even if an error occurred + clearApprovalToolContext(); setIsExecutingTool(false); toolAbortControllerRef.current = null; executingToolCallIdsRef.current = []; @@ -9301,6 +9330,7 @@ ${SYSTEM_REMINDER_CLOSE} queueApprovalResults, consumeQueuedMessages, appendTaskNotificationEvents, + clearApprovalToolContext, syncTrajectoryElapsedBase, closeTrajectorySegment, openTrajectorySegment, @@ -9488,7 +9518,10 @@ ${SYSTEM_REMINDER_CLOSE} onChunk(buffersRef.current, chunk); refreshDerived(); }, - { onStreamingOutput: updateStreamingOutput }, + { + onStreamingOutput: updateStreamingOutput, + toolContextId: approvalToolContextIdRef.current ?? undefined, + }, ); // Combine with auto-handled and auto-denied results (from initial check) diff --git a/src/cli/helpers/stream.ts b/src/cli/helpers/stream.ts index 7576ea0..b76f506 100644 --- a/src/cli/helpers/stream.ts +++ b/src/cli/helpers/stream.ts @@ -3,7 +3,7 @@ import type { Stream } from "@letta-ai/letta-client/core/streaming"; import type { LettaStreamingResponse } from "@letta-ai/letta-client/resources/agents/messages"; import type { StopReasonType } from "@letta-ai/letta-client/resources/runs/runs"; import { getClient } from "../../agent/client"; -import { STREAM_REQUEST_START_TIME } from "../../agent/message"; +import { getStreamRequestStartTime } from "../../agent/message"; import { debugWarn } from "../../utils/debug"; import { formatDuration, logTiming } from "../../utils/timing"; @@ -64,11 +64,7 @@ export async function drainStream( contextTracker?: ContextTracker, ): Promise { const startTime = performance.now(); - - // Extract request start time for TTFT logging (attached by sendMessageStream) - const requestStartTime = ( - stream as unknown as Record - )[STREAM_REQUEST_START_TIME]; + const requestStartTime = getStreamRequestStartTime(stream) ?? startTime; let hasLoggedTTFT = false; const streamProcessor = new StreamProcessor(); @@ -146,7 +142,6 @@ export async function drainStream( // Log TTFT (time-to-first-token) when first content chunk arrives if ( !hasLoggedTTFT && - requestStartTime !== undefined && (chunk.message_type === "reasoning_message" || chunk.message_type === "assistant_message") ) { diff --git a/src/headless.ts b/src/headless.ts index 3b942f4..607ecb1 100644 --- a/src/headless.ts +++ b/src/headless.ts @@ -21,7 +21,7 @@ import { getClient } from "./agent/client"; import { setAgentContext, setConversationId } from "./agent/context"; import { createAgent } from "./agent/create"; import { ISOLATED_BLOCK_LABELS } from "./agent/memory"; -import { sendMessageStream } from "./agent/message"; +import { getStreamToolContextId, sendMessageStream } from "./agent/message"; import { getModelUpdateArgs } from "./agent/model"; import { resolveSkillSourcesSelection } from "./agent/skillSources"; import type { SkillSource } from "./agent/skills"; @@ -1465,10 +1465,12 @@ ${SYSTEM_REMINDER_CLOSE} // Wrap sendMessageStream in try-catch to handle pre-stream errors (e.g., 409) let stream: Awaited>; + let turnToolContextId: string | null = null; try { stream = await sendMessageStream(conversationId, currentInput, { agentId: agent.id, }); + turnToolContextId = getStreamToolContextId(stream); } catch (preStreamError) { // Extract error detail using shared helper (handles nested/direct/message shapes) const errorDetail = extractConflictDetail(preStreamError); @@ -1838,7 +1840,13 @@ ${SYSTEM_REMINDER_CLOSE} const { executeApprovalBatch } = await import( "./agent/approval-execution" ); - const executedResults = await executeApprovalBatch(decisions); + const executedResults = await executeApprovalBatch( + decisions, + undefined, + { + toolContextId: turnToolContextId ?? undefined, + }, + ); // Send all results in one batch currentInput = [ @@ -2854,10 +2862,12 @@ async function runBidirectionalMode( // Send message to agent. // Wrap in try-catch to handle pre-stream 409 approval-pending errors. let stream: Awaited>; + let turnToolContextId: string | null = null; try { stream = await sendMessageStream(conversationId, currentInput, { agentId: agent.id, }); + turnToolContextId = getStreamToolContextId(stream); } catch (preStreamError) { // Extract error detail using shared helper (handles nested/direct/message shapes) const errorDetail = extractConflictDetail(preStreamError); @@ -3135,7 +3145,11 @@ async function runBidirectionalMode( const { executeApprovalBatch } = await import( "./agent/approval-execution" ); - const executedResults = await executeApprovalBatch(decisions); + const executedResults = await executeApprovalBatch( + decisions, + undefined, + { toolContextId: turnToolContextId ?? undefined }, + ); // Send approval results back to continue currentInput = [ diff --git a/src/tests/cli/queue-ordering-wiring.test.ts b/src/tests/cli/queue-ordering-wiring.test.ts new file mode 100644 index 0000000..6c653d3 --- /dev/null +++ b/src/tests/cli/queue-ordering-wiring.test.ts @@ -0,0 +1,90 @@ +import { describe, expect, test } from "bun:test"; +import { readFileSync } from "node:fs"; +import { fileURLToPath } from "node:url"; + +function readAppSource(): string { + const appPath = fileURLToPath(new URL("../../cli/App.tsx", import.meta.url)); + return readFileSync(appPath, "utf-8"); +} + +describe("queue ordering wiring", () => { + test("dequeue effect keeps all sensitive safety gates", () => { + const source = readAppSource(); + const start = source.indexOf( + "// Process queued messages when streaming ends", + ); + const end = source.indexOf( + "// Helper to send all approval results when done", + ); + + expect(start).toBeGreaterThan(-1); + expect(end).toBeGreaterThan(start); + + const segment = source.slice(start, end); + expect(segment).toContain("pendingApprovals.length === 0"); + expect(segment).toContain("!commandRunning"); + expect(segment).toContain("!isExecutingTool"); + expect(segment).toContain("!anySelectorOpen"); + expect(segment).toContain("!queuedOverlayAction"); + expect(segment).toContain("!waitingForQueueCancelRef.current"); + expect(segment).toContain("!userCancelledRef.current"); + expect(segment).toContain("!abortControllerRef.current"); + expect(segment).toContain("queuedOverlayAction="); + expect(segment).toContain("setMessageQueue([]);"); + expect(segment).toContain("onSubmitRef.current(concatenatedMessage);"); + expect(segment).toContain("queuedOverlayAction,"); + }); + + test("queued overlay effect only runs when idle and clears action before processing", () => { + const source = readAppSource(); + const start = source.indexOf( + "// Process queued overlay actions when streaming ends", + ); + const end = source.indexOf( + "// Handle escape when profile confirmation is pending", + ); + + expect(start).toBeGreaterThan(-1); + expect(end).toBeGreaterThan(start); + + const segment = source.slice(start, end); + expect(segment).toContain("!streaming"); + expect(segment).toContain("!commandRunning"); + expect(segment).toContain("!isExecutingTool"); + expect(segment).toContain("pendingApprovals.length === 0"); + expect(segment).toContain("queuedOverlayAction !== null"); + expect(segment).toContain("setQueuedOverlayAction(null)"); + expect(segment).toContain('action.type === "switch_model"'); + expect(segment).toContain("handleModelSelect(action.modelId"); + expect(segment).toContain('action.type === "switch_toolset"'); + expect(segment).toContain("handleToolsetSelect(action.toolsetId"); + }); + + test("busy model/toolset handlers enqueue overlay actions", () => { + const source = readAppSource(); + + const modelAnchor = source.indexOf( + "Model switch queued – will switch after current task completes", + ); + expect(modelAnchor).toBeGreaterThan(-1); + const modelWindow = source.slice( + Math.max(0, modelAnchor - 700), + modelAnchor + 700, + ); + expect(modelWindow).toContain("if (isAgentBusy())"); + expect(modelWindow).toContain("setQueuedOverlayAction({"); + expect(modelWindow).toContain('type: "switch_model"'); + + const toolsetAnchor = source.indexOf( + "Toolset switch queued – will switch after current task completes", + ); + expect(toolsetAnchor).toBeGreaterThan(-1); + const toolsetWindow = source.slice( + Math.max(0, toolsetAnchor - 700), + toolsetAnchor + 700, + ); + expect(toolsetWindow).toContain("if (isAgentBusy())"); + expect(toolsetWindow).toContain("setQueuedOverlayAction({"); + expect(toolsetWindow).toContain('type: "switch_toolset"'); + }); +}); diff --git a/src/tests/tools/tool-execution-context.test.ts b/src/tests/tools/tool-execution-context.test.ts new file mode 100644 index 0000000..9bb444d --- /dev/null +++ b/src/tests/tools/tool-execution-context.test.ts @@ -0,0 +1,78 @@ +import { afterAll, beforeAll, describe, expect, test } from "bun:test"; +import { + captureToolExecutionContext, + clearCapturedToolExecutionContexts, + clearExternalTools, + clearTools, + executeTool, + getToolNames, + loadSpecificTools, +} from "../../tools/manager"; + +function asText( + toolReturn: Awaited>["toolReturn"], +) { + return typeof toolReturn === "string" + ? toolReturn + : JSON.stringify(toolReturn); +} + +describe("tool execution context snapshot", () => { + let initialTools: string[] = []; + + beforeAll(() => { + initialTools = getToolNames(); + }); + + afterAll(async () => { + clearCapturedToolExecutionContexts(); + clearExternalTools(); + if (initialTools.length > 0) { + await loadSpecificTools(initialTools); + } else { + clearTools(); + } + }); + + test("executes Read using captured context after global toolset changes", async () => { + await loadSpecificTools(["Read"]); + const { contextId } = captureToolExecutionContext(); + + await loadSpecificTools(["ReadFile"]); + + const withoutContext = await executeTool("Read", { + file_path: "README.md", + }); + expect(withoutContext.status).toBe("error"); + expect(asText(withoutContext.toolReturn)).toContain("Tool not found: Read"); + + const withContext = await executeTool( + "Read", + { file_path: "README.md" }, + { toolContextId: contextId }, + ); + expect(withContext.status).toBe("success"); + }); + + test("executes ReadFile using captured context after global toolset changes", async () => { + await loadSpecificTools(["ReadFile"]); + const { contextId } = captureToolExecutionContext(); + + await loadSpecificTools(["Read"]); + + const withoutContext = await executeTool("ReadFile", { + file_path: "README.md", + }); + expect(withoutContext.status).toBe("error"); + expect(asText(withoutContext.toolReturn)).toContain( + "Tool not found: ReadFile", + ); + + const withContext = await executeTool( + "ReadFile", + { file_path: "README.md" }, + { toolContextId: contextId }, + ); + expect(withContext.status).toBe("success"); + }); +}); diff --git a/src/tools/manager.ts b/src/tools/manager.ts index ff4e67b..7d149cb 100644 --- a/src/tools/manager.ts +++ b/src/tools/manager.ts @@ -237,6 +237,7 @@ type ToolRegistry = Map; // This prevents Bun's bundler from creating duplicate instances const REGISTRY_KEY = Symbol.for("@letta/toolRegistry"); const SWITCH_LOCK_KEY = Symbol.for("@letta/toolSwitchLock"); +const EXECUTION_CONTEXTS_KEY = Symbol.for("@letta/toolExecutionContexts"); interface SwitchLockState { promise: Promise | null; @@ -247,6 +248,7 @@ interface SwitchLockState { type GlobalWithToolState = typeof globalThis & { [REGISTRY_KEY]?: ToolRegistry; [SWITCH_LOCK_KEY]?: SwitchLockState; + [EXECUTION_CONTEXTS_KEY]?: Map; }; function getRegistry(): ToolRegistry { @@ -266,6 +268,57 @@ function getSwitchLock(): SwitchLockState { } const toolRegistry = getRegistry(); +let toolExecutionContextCounter = 0; + +type ToolExecutionContextSnapshot = { + toolRegistry: ToolRegistry; + externalTools: Map; + externalExecutor?: ExternalToolExecutor; +}; + +export type CapturedToolExecutionContext = { + contextId: string; + clientTools: ClientTool[]; +}; + +function getExecutionContexts(): Map { + const global = globalThis as GlobalWithToolState; + if (!global[EXECUTION_CONTEXTS_KEY]) { + global[EXECUTION_CONTEXTS_KEY] = new Map(); + } + return global[EXECUTION_CONTEXTS_KEY]; +} + +function saveExecutionContext(snapshot: ToolExecutionContextSnapshot): string { + const contexts = getExecutionContexts(); + const contextId = `ctx-${Date.now()}-${toolExecutionContextCounter++}`; + contexts.set(contextId, snapshot); + + // Keep memory bounded; stale turns won't need old snapshots. + const MAX_CONTEXTS = 4096; + if (contexts.size > MAX_CONTEXTS) { + const oldestContextId = contexts.keys().next().value; + if (oldestContextId) { + contexts.delete(oldestContextId); + } + } + + return contextId; +} + +function getExecutionContextById( + contextId: string, +): ToolExecutionContextSnapshot | undefined { + return getExecutionContexts().get(contextId); +} + +export function clearCapturedToolExecutionContexts(): void { + getExecutionContexts().clear(); +} + +export function releaseToolExecutionContext(contextId: string): void { + getExecutionContexts().delete(contextId); +} /** * Acquires the toolset switch lock. Call before starting async tool loading. @@ -331,13 +384,16 @@ export function isToolsetSwitchInProgress(): boolean { * - Otherwise, fall back to the alias mapping used for Gemini tools. * - Returns undefined if no matching tool is loaded. */ -function resolveInternalToolName(name: string): string | undefined { - if (toolRegistry.has(name)) { +function resolveInternalToolName( + name: string, + registry: ToolRegistry = toolRegistry, +): string | undefined { + if (registry.has(name)) { return name; } const internalName = getInternalToolName(name); - if (toolRegistry.has(internalName)) { + if (registry.has(internalName)) { return internalName; } @@ -419,6 +475,10 @@ export function setExternalToolExecutor(executor: ExternalToolExecutor): void { (globalThis as GlobalWithExternalTools)[EXTERNAL_EXECUTOR_KEY] = executor; } +function getExternalToolExecutor(): ExternalToolExecutor | undefined { + return (globalThis as GlobalWithExternalTools)[EXTERNAL_EXECUTOR_KEY]; +} + /** * Clear external tools (for testing or session cleanup) */ @@ -461,10 +521,9 @@ export async function executeExternalTool( toolCallId: string, toolName: string, input: Record, + executorOverride?: ExternalToolExecutor, ): Promise { - const executor = (globalThis as GlobalWithExternalTools)[ - EXTERNAL_EXECUTOR_KEY - ]; + const executor = executorOverride ?? getExternalToolExecutor(); if (!executor) { return { toolReturn: `External tool executor not set for tool: ${toolName}`, @@ -518,6 +577,40 @@ export function getClientToolsFromRegistry(): ClientTool[] { return [...builtInTools, ...externalTools]; } +/** + * Capture a turn-scoped tool snapshot and corresponding client_tools payload. + * The returned context id can be used later to execute tool calls against this + * exact snapshot even if the global registry changes between dispatch and execute. + */ +export function captureToolExecutionContext(): CapturedToolExecutionContext { + const snapshot: ToolExecutionContextSnapshot = { + toolRegistry: new Map(toolRegistry), + externalTools: new Map(getExternalToolsRegistry()), + externalExecutor: getExternalToolExecutor(), + }; + const contextId = saveExecutionContext(snapshot); + + const builtInTools = Array.from(snapshot.toolRegistry.entries()).map( + ([name, tool]) => ({ + name: getServerToolName(name), + description: tool.schema.description, + parameters: tool.schema.input_schema, + }), + ); + const externalTools = Array.from(snapshot.externalTools.values()).map( + (tool) => ({ + name: tool.name, + description: tool.description, + parameters: tool.parameters, + }), + ); + + return { + contextId, + clientTools: [...builtInTools, ...externalTools], + }; +} + /** * Get permissions for a specific tool. * @param toolName - The name of the tool @@ -1030,29 +1123,46 @@ export async function executeTool( signal?: AbortSignal; toolCallId?: string; onOutput?: (chunk: string, stream: "stdout" | "stderr") => void; + toolContextId?: string; }, ): Promise { + const context = options?.toolContextId + ? getExecutionContextById(options.toolContextId) + : undefined; + if (options?.toolContextId && !context) { + return { + toolReturn: `Tool execution context not found: ${options.toolContextId}`, + status: "error", + }; + } + const activeRegistry = context?.toolRegistry ?? toolRegistry; + const activeExternalTools = + context?.externalTools ?? getExternalToolsRegistry(); + const activeExternalExecutor = + context?.externalExecutor ?? getExternalToolExecutor(); + // Check if this is an external tool (SDK-executed) - if (isExternalTool(name)) { + if (activeExternalTools.has(name)) { return executeExternalTool( options?.toolCallId ?? `ext-${Date.now()}`, name, args as Record, + activeExternalExecutor, ); } - const internalName = resolveInternalToolName(name); + const internalName = resolveInternalToolName(name, activeRegistry); if (!internalName) { return { - toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(toolRegistry.keys()).join(", ")}`, + toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(activeRegistry.keys()).join(", ")}`, status: "error", }; } - const tool = toolRegistry.get(internalName); + const tool = activeRegistry.get(internalName); if (!tool) { return { - toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(toolRegistry.keys()).join(", ")}`, + toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(activeRegistry.keys()).join(", ")}`, status: "error", }; }