fix: execute tools against dispatch-time snapshot (#1018)
This commit is contained in:
@@ -189,6 +189,7 @@ async function executeSingleDecision(
|
||||
chunk: string,
|
||||
isStderr?: boolean,
|
||||
) => void;
|
||||
toolContextId?: string;
|
||||
},
|
||||
): Promise<ApprovalResult> {
|
||||
// If aborted, record an interrupted result
|
||||
@@ -245,6 +246,7 @@ async function executeSingleDecision(
|
||||
{
|
||||
signal: options?.abortSignal,
|
||||
toolCallId: decision.approval.toolCallId,
|
||||
toolContextId: options?.toolContextId,
|
||||
onOutput: options?.onStreamingOutput
|
||||
? (chunk, stream) =>
|
||||
options.onStreamingOutput?.(
|
||||
@@ -357,6 +359,7 @@ export async function executeApprovalBatch(
|
||||
chunk: string,
|
||||
isStderr?: boolean,
|
||||
) => void;
|
||||
toolContextId?: string;
|
||||
},
|
||||
): Promise<ApprovalResult[]> {
|
||||
// Pre-allocate results array to maintain original order
|
||||
@@ -452,6 +455,7 @@ export async function executeAutoAllowedTools(
|
||||
chunk: string,
|
||||
isStderr?: boolean,
|
||||
) => void;
|
||||
toolContextId?: string;
|
||||
},
|
||||
): Promise<AutoAllowedResult[]> {
|
||||
const decisions: ApprovalDecision[] = autoAllowed.map((ac) => ({
|
||||
|
||||
@@ -9,14 +9,26 @@ import type {
|
||||
LettaStreamingResponse,
|
||||
} from "@letta-ai/letta-client/resources/agents/messages";
|
||||
import {
|
||||
getClientToolsFromRegistry,
|
||||
captureToolExecutionContext,
|
||||
waitForToolsetReady,
|
||||
} from "../tools/manager";
|
||||
import { isTimingsEnabled } from "../utils/timing";
|
||||
import { getClient } from "./client";
|
||||
|
||||
// Symbol to store timing info on the stream object
|
||||
export const STREAM_REQUEST_START_TIME = Symbol("streamRequestStartTime");
|
||||
const streamRequestStartTimes = new WeakMap<object, number>();
|
||||
const streamToolContextIds = new WeakMap<object, string>();
|
||||
|
||||
export function getStreamRequestStartTime(
|
||||
stream: Stream<LettaStreamingResponse>,
|
||||
): number | undefined {
|
||||
return streamRequestStartTimes.get(stream as object);
|
||||
}
|
||||
|
||||
export function getStreamToolContextId(
|
||||
stream: Stream<LettaStreamingResponse>,
|
||||
): string | null {
|
||||
return streamToolContextIds.get(stream as object) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to a conversation and return a streaming response.
|
||||
@@ -40,14 +52,13 @@ export async function sendMessageStream(
|
||||
// requestOptions: { maxRetries?: number } = { maxRetries: 0 },
|
||||
requestOptions: { maxRetries?: number } = {},
|
||||
): Promise<Stream<LettaStreamingResponse>> {
|
||||
// Capture request start time for TTFT measurement when timings are enabled
|
||||
const requestStartTime = isTimingsEnabled() ? performance.now() : undefined;
|
||||
|
||||
const client = await getClient();
|
||||
|
||||
// Wait for any in-progress toolset switch to complete before reading tools
|
||||
// This prevents sending messages with stale tools during a switch
|
||||
await waitForToolsetReady();
|
||||
const { clientTools, contextId } = captureToolExecutionContext();
|
||||
|
||||
let stream: Stream<LettaStreamingResponse>;
|
||||
|
||||
@@ -71,7 +82,7 @@ export async function sendMessageStream(
|
||||
streaming: true,
|
||||
stream_tokens: opts.streamTokens ?? true,
|
||||
background: opts.background ?? true,
|
||||
client_tools: getClientToolsFromRegistry(),
|
||||
client_tools: clientTools,
|
||||
include_compaction_messages: true,
|
||||
},
|
||||
requestOptions,
|
||||
@@ -85,18 +96,17 @@ export async function sendMessageStream(
|
||||
streaming: true,
|
||||
stream_tokens: opts.streamTokens ?? true,
|
||||
background: opts.background ?? true,
|
||||
client_tools: getClientToolsFromRegistry(),
|
||||
client_tools: clientTools,
|
||||
include_compaction_messages: true,
|
||||
},
|
||||
requestOptions,
|
||||
);
|
||||
}
|
||||
|
||||
// Attach start time to stream for TTFT calculation in drainStream
|
||||
if (requestStartTime !== undefined) {
|
||||
(stream as unknown as Record<symbol, number>)[STREAM_REQUEST_START_TIME] =
|
||||
requestStartTime;
|
||||
streamRequestStartTimes.set(stream as object, requestStartTime);
|
||||
}
|
||||
streamToolContextIds.set(stream as object, contextId);
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ import {
|
||||
ensureMemoryFilesystemDirs,
|
||||
getMemoryFilesystemRoot,
|
||||
} from "../agent/memoryFilesystem";
|
||||
import { sendMessageStream } from "../agent/message";
|
||||
import { getStreamToolContextId, sendMessageStream } from "../agent/message";
|
||||
import {
|
||||
getModelInfo,
|
||||
getModelShortName,
|
||||
@@ -95,6 +95,7 @@ import {
|
||||
analyzeToolApproval,
|
||||
checkToolPermission,
|
||||
executeTool,
|
||||
releaseToolExecutionContext,
|
||||
savePermissionRule,
|
||||
type ToolExecutionResult,
|
||||
} from "../tools/manager";
|
||||
@@ -1570,6 +1571,13 @@ export default function App({
|
||||
const lastSentInputRef = useRef<Array<MessageCreate | ApprovalCreate> | null>(
|
||||
null,
|
||||
);
|
||||
const approvalToolContextIdRef = useRef<string | null>(null);
|
||||
const clearApprovalToolContext = useCallback(() => {
|
||||
const contextId = approvalToolContextIdRef.current;
|
||||
if (!contextId) return;
|
||||
approvalToolContextIdRef.current = null;
|
||||
releaseToolExecutionContext(contextId);
|
||||
}, []);
|
||||
// Non-null only when the previous turn was explicitly interrupted by the user.
|
||||
// Used to gate recovery alert injection to true user-interrupt retries.
|
||||
const pendingInterruptRecoveryConversationIdRef = useRef<string | null>(null);
|
||||
@@ -3173,12 +3181,14 @@ export default function App({
|
||||
// throws before streaming begins, e.g., retry after LLM error when backend
|
||||
// already cleared the approval)
|
||||
let stream: Awaited<ReturnType<typeof sendMessageStream>>;
|
||||
let turnToolContextId: string | null = null;
|
||||
try {
|
||||
stream = await sendMessageStream(
|
||||
conversationIdRef.current,
|
||||
currentInput,
|
||||
{ agentId: agentIdRef.current },
|
||||
);
|
||||
turnToolContextId = getStreamToolContextId(stream);
|
||||
} catch (preStreamError) {
|
||||
// Extract error detail using shared helper (handles nested/direct/message shapes)
|
||||
const errorDetail = extractConflictDetail(preStreamError);
|
||||
@@ -3599,6 +3609,7 @@ export default function App({
|
||||
|
||||
// Case 1: Turn ended normally
|
||||
if (stopReasonToHandle === "end_turn") {
|
||||
clearApprovalToolContext();
|
||||
setStreaming(false);
|
||||
const liveElapsedMs = (() => {
|
||||
const snapshot = sessionStatsRef.current.getTrajectorySnapshot();
|
||||
@@ -3775,6 +3786,7 @@ export default function App({
|
||||
|
||||
// Case 1.5: Stream was cancelled by user
|
||||
if (stopReasonToHandle === "cancelled") {
|
||||
clearApprovalToolContext();
|
||||
setStreaming(false);
|
||||
closeTrajectorySegment();
|
||||
syncTrajectoryElapsedBase();
|
||||
@@ -3824,6 +3836,8 @@ export default function App({
|
||||
|
||||
// Case 2: Requires approval
|
||||
if (stopReasonToHandle === "requires_approval") {
|
||||
clearApprovalToolContext();
|
||||
approvalToolContextIdRef.current = turnToolContextId;
|
||||
// Clear stale state immediately to prevent ID mismatch bugs
|
||||
setAutoHandledResults([]);
|
||||
setAutoDeniedApprovals([]);
|
||||
@@ -3839,6 +3853,7 @@ export default function App({
|
||||
: [];
|
||||
|
||||
if (approvalsToProcess.length === 0) {
|
||||
clearApprovalToolContext();
|
||||
appendError(
|
||||
`Unexpected empty approvals with stop reason: ${stopReason}`,
|
||||
);
|
||||
@@ -3851,6 +3866,7 @@ export default function App({
|
||||
// If in quietCancel mode (user queued messages), auto-reject all approvals
|
||||
// and send denials + queued messages together
|
||||
if (waitingForQueueCancelRef.current) {
|
||||
clearApprovalToolContext();
|
||||
// Create denial results for all approvals
|
||||
const denialResults = approvalsToProcess.map((approvalItem) => ({
|
||||
type: "approval" as const,
|
||||
@@ -3898,6 +3914,7 @@ export default function App({
|
||||
userCancelledRef.current ||
|
||||
abortControllerRef.current?.signal.aborted
|
||||
) {
|
||||
clearApprovalToolContext();
|
||||
setStreaming(false);
|
||||
closeTrajectorySegment();
|
||||
syncTrajectoryElapsedBase();
|
||||
@@ -4034,6 +4051,8 @@ export default function App({
|
||||
{
|
||||
abortSignal: autoAllowedAbortController.signal,
|
||||
onStreamingOutput: updateStreamingOutput,
|
||||
toolContextId:
|
||||
approvalToolContextIdRef.current ?? undefined,
|
||||
},
|
||||
)
|
||||
: [];
|
||||
@@ -4744,6 +4763,7 @@ export default function App({
|
||||
consumeQueuedMessages,
|
||||
appendTaskNotificationEvents,
|
||||
maybeCheckMemoryGitStatus,
|
||||
clearApprovalToolContext,
|
||||
openTrajectorySegment,
|
||||
syncTrajectoryTokenBase,
|
||||
syncTrajectoryElapsedBase,
|
||||
@@ -5550,6 +5570,7 @@ export default function App({
|
||||
{
|
||||
abortSignal: autoAllowedAbortController.signal,
|
||||
onStreamingOutput: updateStreamingOutput,
|
||||
toolContextId: approvalToolContextIdRef.current ?? undefined,
|
||||
},
|
||||
);
|
||||
// Map to ApprovalResult format (ToolReturn)
|
||||
@@ -8572,6 +8593,8 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
{
|
||||
abortSignal: autoAllowedAbortController.signal,
|
||||
onStreamingOutput: updateStreamingOutput,
|
||||
toolContextId:
|
||||
approvalToolContextIdRef.current ?? undefined,
|
||||
},
|
||||
)
|
||||
: [];
|
||||
@@ -8816,6 +8839,8 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
{
|
||||
abortSignal: autoAllowedAbortController.signal,
|
||||
onStreamingOutput: updateStreamingOutput,
|
||||
toolContextId:
|
||||
approvalToolContextIdRef.current ?? undefined,
|
||||
},
|
||||
)
|
||||
: [];
|
||||
@@ -9002,6 +9027,7 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
if (
|
||||
!streaming &&
|
||||
hasAnythingQueued &&
|
||||
!queuedOverlayAction && // Prioritize queued model/toolset/system switches before dequeuing messages
|
||||
pendingApprovals.length === 0 &&
|
||||
!commandRunning &&
|
||||
!isExecutingTool &&
|
||||
@@ -9035,7 +9061,7 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
// Log why dequeue was blocked (useful for debugging stuck queues)
|
||||
debugLog(
|
||||
"queue",
|
||||
`Dequeue blocked: streaming=${streaming}, pendingApprovals=${pendingApprovals.length}, commandRunning=${commandRunning}, isExecutingTool=${isExecutingTool}, anySelectorOpen=${anySelectorOpen}, waitingForQueueCancel=${waitingForQueueCancelRef.current}, userCancelled=${userCancelledRef.current}, abortController=${!!abortControllerRef.current}`,
|
||||
`Dequeue blocked: streaming=${streaming}, queuedOverlayAction=${!!queuedOverlayAction}, pendingApprovals=${pendingApprovals.length}, commandRunning=${commandRunning}, isExecutingTool=${isExecutingTool}, anySelectorOpen=${anySelectorOpen}, waitingForQueueCancel=${waitingForQueueCancelRef.current}, userCancelled=${userCancelledRef.current}, abortController=${!!abortControllerRef.current}`,
|
||||
);
|
||||
}
|
||||
}, [
|
||||
@@ -9045,6 +9071,7 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
commandRunning,
|
||||
isExecutingTool,
|
||||
anySelectorOpen,
|
||||
queuedOverlayAction,
|
||||
dequeueEpoch, // Triggered when userCancelledRef is reset OR task notifications added
|
||||
]);
|
||||
|
||||
@@ -9155,6 +9182,7 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
{
|
||||
abortSignal: approvalAbortController.signal,
|
||||
onStreamingOutput: updateStreamingOutput,
|
||||
toolContextId: approvalToolContextIdRef.current ?? undefined,
|
||||
},
|
||||
);
|
||||
} finally {
|
||||
@@ -9281,6 +9309,7 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
}
|
||||
} finally {
|
||||
// Always release the execution guard, even if an error occurred
|
||||
clearApprovalToolContext();
|
||||
setIsExecutingTool(false);
|
||||
toolAbortControllerRef.current = null;
|
||||
executingToolCallIdsRef.current = [];
|
||||
@@ -9301,6 +9330,7 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
queueApprovalResults,
|
||||
consumeQueuedMessages,
|
||||
appendTaskNotificationEvents,
|
||||
clearApprovalToolContext,
|
||||
syncTrajectoryElapsedBase,
|
||||
closeTrajectorySegment,
|
||||
openTrajectorySegment,
|
||||
@@ -9488,7 +9518,10 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
onChunk(buffersRef.current, chunk);
|
||||
refreshDerived();
|
||||
},
|
||||
{ onStreamingOutput: updateStreamingOutput },
|
||||
{
|
||||
onStreamingOutput: updateStreamingOutput,
|
||||
toolContextId: approvalToolContextIdRef.current ?? undefined,
|
||||
},
|
||||
);
|
||||
|
||||
// Combine with auto-handled and auto-denied results (from initial check)
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { Stream } from "@letta-ai/letta-client/core/streaming";
|
||||
import type { LettaStreamingResponse } from "@letta-ai/letta-client/resources/agents/messages";
|
||||
import type { StopReasonType } from "@letta-ai/letta-client/resources/runs/runs";
|
||||
import { getClient } from "../../agent/client";
|
||||
import { STREAM_REQUEST_START_TIME } from "../../agent/message";
|
||||
import { getStreamRequestStartTime } from "../../agent/message";
|
||||
import { debugWarn } from "../../utils/debug";
|
||||
import { formatDuration, logTiming } from "../../utils/timing";
|
||||
|
||||
@@ -64,11 +64,7 @@ export async function drainStream(
|
||||
contextTracker?: ContextTracker,
|
||||
): Promise<DrainResult> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Extract request start time for TTFT logging (attached by sendMessageStream)
|
||||
const requestStartTime = (
|
||||
stream as unknown as Record<symbol, number | undefined>
|
||||
)[STREAM_REQUEST_START_TIME];
|
||||
const requestStartTime = getStreamRequestStartTime(stream) ?? startTime;
|
||||
let hasLoggedTTFT = false;
|
||||
|
||||
const streamProcessor = new StreamProcessor();
|
||||
@@ -146,7 +142,6 @@ export async function drainStream(
|
||||
// Log TTFT (time-to-first-token) when first content chunk arrives
|
||||
if (
|
||||
!hasLoggedTTFT &&
|
||||
requestStartTime !== undefined &&
|
||||
(chunk.message_type === "reasoning_message" ||
|
||||
chunk.message_type === "assistant_message")
|
||||
) {
|
||||
|
||||
@@ -21,7 +21,7 @@ import { getClient } from "./agent/client";
|
||||
import { setAgentContext, setConversationId } from "./agent/context";
|
||||
import { createAgent } from "./agent/create";
|
||||
import { ISOLATED_BLOCK_LABELS } from "./agent/memory";
|
||||
import { sendMessageStream } from "./agent/message";
|
||||
import { getStreamToolContextId, sendMessageStream } from "./agent/message";
|
||||
import { getModelUpdateArgs } from "./agent/model";
|
||||
import { resolveSkillSourcesSelection } from "./agent/skillSources";
|
||||
import type { SkillSource } from "./agent/skills";
|
||||
@@ -1465,10 +1465,12 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
|
||||
// Wrap sendMessageStream in try-catch to handle pre-stream errors (e.g., 409)
|
||||
let stream: Awaited<ReturnType<typeof sendMessageStream>>;
|
||||
let turnToolContextId: string | null = null;
|
||||
try {
|
||||
stream = await sendMessageStream(conversationId, currentInput, {
|
||||
agentId: agent.id,
|
||||
});
|
||||
turnToolContextId = getStreamToolContextId(stream);
|
||||
} catch (preStreamError) {
|
||||
// Extract error detail using shared helper (handles nested/direct/message shapes)
|
||||
const errorDetail = extractConflictDetail(preStreamError);
|
||||
@@ -1838,7 +1840,13 @@ ${SYSTEM_REMINDER_CLOSE}
|
||||
const { executeApprovalBatch } = await import(
|
||||
"./agent/approval-execution"
|
||||
);
|
||||
const executedResults = await executeApprovalBatch(decisions);
|
||||
const executedResults = await executeApprovalBatch(
|
||||
decisions,
|
||||
undefined,
|
||||
{
|
||||
toolContextId: turnToolContextId ?? undefined,
|
||||
},
|
||||
);
|
||||
|
||||
// Send all results in one batch
|
||||
currentInput = [
|
||||
@@ -2854,10 +2862,12 @@ async function runBidirectionalMode(
|
||||
// Send message to agent.
|
||||
// Wrap in try-catch to handle pre-stream 409 approval-pending errors.
|
||||
let stream: Awaited<ReturnType<typeof sendMessageStream>>;
|
||||
let turnToolContextId: string | null = null;
|
||||
try {
|
||||
stream = await sendMessageStream(conversationId, currentInput, {
|
||||
agentId: agent.id,
|
||||
});
|
||||
turnToolContextId = getStreamToolContextId(stream);
|
||||
} catch (preStreamError) {
|
||||
// Extract error detail using shared helper (handles nested/direct/message shapes)
|
||||
const errorDetail = extractConflictDetail(preStreamError);
|
||||
@@ -3135,7 +3145,11 @@ async function runBidirectionalMode(
|
||||
const { executeApprovalBatch } = await import(
|
||||
"./agent/approval-execution"
|
||||
);
|
||||
const executedResults = await executeApprovalBatch(decisions);
|
||||
const executedResults = await executeApprovalBatch(
|
||||
decisions,
|
||||
undefined,
|
||||
{ toolContextId: turnToolContextId ?? undefined },
|
||||
);
|
||||
|
||||
// Send approval results back to continue
|
||||
currentInput = [
|
||||
|
||||
90
src/tests/cli/queue-ordering-wiring.test.ts
Normal file
90
src/tests/cli/queue-ordering-wiring.test.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { describe, expect, test } from "bun:test";
|
||||
import { readFileSync } from "node:fs";
|
||||
import { fileURLToPath } from "node:url";
|
||||
|
||||
function readAppSource(): string {
|
||||
const appPath = fileURLToPath(new URL("../../cli/App.tsx", import.meta.url));
|
||||
return readFileSync(appPath, "utf-8");
|
||||
}
|
||||
|
||||
describe("queue ordering wiring", () => {
|
||||
test("dequeue effect keeps all sensitive safety gates", () => {
|
||||
const source = readAppSource();
|
||||
const start = source.indexOf(
|
||||
"// Process queued messages when streaming ends",
|
||||
);
|
||||
const end = source.indexOf(
|
||||
"// Helper to send all approval results when done",
|
||||
);
|
||||
|
||||
expect(start).toBeGreaterThan(-1);
|
||||
expect(end).toBeGreaterThan(start);
|
||||
|
||||
const segment = source.slice(start, end);
|
||||
expect(segment).toContain("pendingApprovals.length === 0");
|
||||
expect(segment).toContain("!commandRunning");
|
||||
expect(segment).toContain("!isExecutingTool");
|
||||
expect(segment).toContain("!anySelectorOpen");
|
||||
expect(segment).toContain("!queuedOverlayAction");
|
||||
expect(segment).toContain("!waitingForQueueCancelRef.current");
|
||||
expect(segment).toContain("!userCancelledRef.current");
|
||||
expect(segment).toContain("!abortControllerRef.current");
|
||||
expect(segment).toContain("queuedOverlayAction=");
|
||||
expect(segment).toContain("setMessageQueue([]);");
|
||||
expect(segment).toContain("onSubmitRef.current(concatenatedMessage);");
|
||||
expect(segment).toContain("queuedOverlayAction,");
|
||||
});
|
||||
|
||||
test("queued overlay effect only runs when idle and clears action before processing", () => {
|
||||
const source = readAppSource();
|
||||
const start = source.indexOf(
|
||||
"// Process queued overlay actions when streaming ends",
|
||||
);
|
||||
const end = source.indexOf(
|
||||
"// Handle escape when profile confirmation is pending",
|
||||
);
|
||||
|
||||
expect(start).toBeGreaterThan(-1);
|
||||
expect(end).toBeGreaterThan(start);
|
||||
|
||||
const segment = source.slice(start, end);
|
||||
expect(segment).toContain("!streaming");
|
||||
expect(segment).toContain("!commandRunning");
|
||||
expect(segment).toContain("!isExecutingTool");
|
||||
expect(segment).toContain("pendingApprovals.length === 0");
|
||||
expect(segment).toContain("queuedOverlayAction !== null");
|
||||
expect(segment).toContain("setQueuedOverlayAction(null)");
|
||||
expect(segment).toContain('action.type === "switch_model"');
|
||||
expect(segment).toContain("handleModelSelect(action.modelId");
|
||||
expect(segment).toContain('action.type === "switch_toolset"');
|
||||
expect(segment).toContain("handleToolsetSelect(action.toolsetId");
|
||||
});
|
||||
|
||||
test("busy model/toolset handlers enqueue overlay actions", () => {
|
||||
const source = readAppSource();
|
||||
|
||||
const modelAnchor = source.indexOf(
|
||||
"Model switch queued – will switch after current task completes",
|
||||
);
|
||||
expect(modelAnchor).toBeGreaterThan(-1);
|
||||
const modelWindow = source.slice(
|
||||
Math.max(0, modelAnchor - 700),
|
||||
modelAnchor + 700,
|
||||
);
|
||||
expect(modelWindow).toContain("if (isAgentBusy())");
|
||||
expect(modelWindow).toContain("setQueuedOverlayAction({");
|
||||
expect(modelWindow).toContain('type: "switch_model"');
|
||||
|
||||
const toolsetAnchor = source.indexOf(
|
||||
"Toolset switch queued – will switch after current task completes",
|
||||
);
|
||||
expect(toolsetAnchor).toBeGreaterThan(-1);
|
||||
const toolsetWindow = source.slice(
|
||||
Math.max(0, toolsetAnchor - 700),
|
||||
toolsetAnchor + 700,
|
||||
);
|
||||
expect(toolsetWindow).toContain("if (isAgentBusy())");
|
||||
expect(toolsetWindow).toContain("setQueuedOverlayAction({");
|
||||
expect(toolsetWindow).toContain('type: "switch_toolset"');
|
||||
});
|
||||
});
|
||||
78
src/tests/tools/tool-execution-context.test.ts
Normal file
78
src/tests/tools/tool-execution-context.test.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
import { afterAll, beforeAll, describe, expect, test } from "bun:test";
|
||||
import {
|
||||
captureToolExecutionContext,
|
||||
clearCapturedToolExecutionContexts,
|
||||
clearExternalTools,
|
||||
clearTools,
|
||||
executeTool,
|
||||
getToolNames,
|
||||
loadSpecificTools,
|
||||
} from "../../tools/manager";
|
||||
|
||||
function asText(
|
||||
toolReturn: Awaited<ReturnType<typeof executeTool>>["toolReturn"],
|
||||
) {
|
||||
return typeof toolReturn === "string"
|
||||
? toolReturn
|
||||
: JSON.stringify(toolReturn);
|
||||
}
|
||||
|
||||
describe("tool execution context snapshot", () => {
|
||||
let initialTools: string[] = [];
|
||||
|
||||
beforeAll(() => {
|
||||
initialTools = getToolNames();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
clearCapturedToolExecutionContexts();
|
||||
clearExternalTools();
|
||||
if (initialTools.length > 0) {
|
||||
await loadSpecificTools(initialTools);
|
||||
} else {
|
||||
clearTools();
|
||||
}
|
||||
});
|
||||
|
||||
test("executes Read using captured context after global toolset changes", async () => {
|
||||
await loadSpecificTools(["Read"]);
|
||||
const { contextId } = captureToolExecutionContext();
|
||||
|
||||
await loadSpecificTools(["ReadFile"]);
|
||||
|
||||
const withoutContext = await executeTool("Read", {
|
||||
file_path: "README.md",
|
||||
});
|
||||
expect(withoutContext.status).toBe("error");
|
||||
expect(asText(withoutContext.toolReturn)).toContain("Tool not found: Read");
|
||||
|
||||
const withContext = await executeTool(
|
||||
"Read",
|
||||
{ file_path: "README.md" },
|
||||
{ toolContextId: contextId },
|
||||
);
|
||||
expect(withContext.status).toBe("success");
|
||||
});
|
||||
|
||||
test("executes ReadFile using captured context after global toolset changes", async () => {
|
||||
await loadSpecificTools(["ReadFile"]);
|
||||
const { contextId } = captureToolExecutionContext();
|
||||
|
||||
await loadSpecificTools(["Read"]);
|
||||
|
||||
const withoutContext = await executeTool("ReadFile", {
|
||||
file_path: "README.md",
|
||||
});
|
||||
expect(withoutContext.status).toBe("error");
|
||||
expect(asText(withoutContext.toolReturn)).toContain(
|
||||
"Tool not found: ReadFile",
|
||||
);
|
||||
|
||||
const withContext = await executeTool(
|
||||
"ReadFile",
|
||||
{ file_path: "README.md" },
|
||||
{ toolContextId: contextId },
|
||||
);
|
||||
expect(withContext.status).toBe("success");
|
||||
});
|
||||
});
|
||||
@@ -237,6 +237,7 @@ type ToolRegistry = Map<string, ToolDefinition>;
|
||||
// This prevents Bun's bundler from creating duplicate instances
|
||||
const REGISTRY_KEY = Symbol.for("@letta/toolRegistry");
|
||||
const SWITCH_LOCK_KEY = Symbol.for("@letta/toolSwitchLock");
|
||||
const EXECUTION_CONTEXTS_KEY = Symbol.for("@letta/toolExecutionContexts");
|
||||
|
||||
interface SwitchLockState {
|
||||
promise: Promise<void> | null;
|
||||
@@ -247,6 +248,7 @@ interface SwitchLockState {
|
||||
type GlobalWithToolState = typeof globalThis & {
|
||||
[REGISTRY_KEY]?: ToolRegistry;
|
||||
[SWITCH_LOCK_KEY]?: SwitchLockState;
|
||||
[EXECUTION_CONTEXTS_KEY]?: Map<string, ToolExecutionContextSnapshot>;
|
||||
};
|
||||
|
||||
function getRegistry(): ToolRegistry {
|
||||
@@ -266,6 +268,57 @@ function getSwitchLock(): SwitchLockState {
|
||||
}
|
||||
|
||||
const toolRegistry = getRegistry();
|
||||
let toolExecutionContextCounter = 0;
|
||||
|
||||
type ToolExecutionContextSnapshot = {
|
||||
toolRegistry: ToolRegistry;
|
||||
externalTools: Map<string, ExternalToolDefinition>;
|
||||
externalExecutor?: ExternalToolExecutor;
|
||||
};
|
||||
|
||||
export type CapturedToolExecutionContext = {
|
||||
contextId: string;
|
||||
clientTools: ClientTool[];
|
||||
};
|
||||
|
||||
function getExecutionContexts(): Map<string, ToolExecutionContextSnapshot> {
|
||||
const global = globalThis as GlobalWithToolState;
|
||||
if (!global[EXECUTION_CONTEXTS_KEY]) {
|
||||
global[EXECUTION_CONTEXTS_KEY] = new Map();
|
||||
}
|
||||
return global[EXECUTION_CONTEXTS_KEY];
|
||||
}
|
||||
|
||||
function saveExecutionContext(snapshot: ToolExecutionContextSnapshot): string {
|
||||
const contexts = getExecutionContexts();
|
||||
const contextId = `ctx-${Date.now()}-${toolExecutionContextCounter++}`;
|
||||
contexts.set(contextId, snapshot);
|
||||
|
||||
// Keep memory bounded; stale turns won't need old snapshots.
|
||||
const MAX_CONTEXTS = 4096;
|
||||
if (contexts.size > MAX_CONTEXTS) {
|
||||
const oldestContextId = contexts.keys().next().value;
|
||||
if (oldestContextId) {
|
||||
contexts.delete(oldestContextId);
|
||||
}
|
||||
}
|
||||
|
||||
return contextId;
|
||||
}
|
||||
|
||||
function getExecutionContextById(
|
||||
contextId: string,
|
||||
): ToolExecutionContextSnapshot | undefined {
|
||||
return getExecutionContexts().get(contextId);
|
||||
}
|
||||
|
||||
export function clearCapturedToolExecutionContexts(): void {
|
||||
getExecutionContexts().clear();
|
||||
}
|
||||
|
||||
export function releaseToolExecutionContext(contextId: string): void {
|
||||
getExecutionContexts().delete(contextId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Acquires the toolset switch lock. Call before starting async tool loading.
|
||||
@@ -331,13 +384,16 @@ export function isToolsetSwitchInProgress(): boolean {
|
||||
* - Otherwise, fall back to the alias mapping used for Gemini tools.
|
||||
* - Returns undefined if no matching tool is loaded.
|
||||
*/
|
||||
function resolveInternalToolName(name: string): string | undefined {
|
||||
if (toolRegistry.has(name)) {
|
||||
function resolveInternalToolName(
|
||||
name: string,
|
||||
registry: ToolRegistry = toolRegistry,
|
||||
): string | undefined {
|
||||
if (registry.has(name)) {
|
||||
return name;
|
||||
}
|
||||
|
||||
const internalName = getInternalToolName(name);
|
||||
if (toolRegistry.has(internalName)) {
|
||||
if (registry.has(internalName)) {
|
||||
return internalName;
|
||||
}
|
||||
|
||||
@@ -419,6 +475,10 @@ export function setExternalToolExecutor(executor: ExternalToolExecutor): void {
|
||||
(globalThis as GlobalWithExternalTools)[EXTERNAL_EXECUTOR_KEY] = executor;
|
||||
}
|
||||
|
||||
function getExternalToolExecutor(): ExternalToolExecutor | undefined {
|
||||
return (globalThis as GlobalWithExternalTools)[EXTERNAL_EXECUTOR_KEY];
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear external tools (for testing or session cleanup)
|
||||
*/
|
||||
@@ -461,10 +521,9 @@ export async function executeExternalTool(
|
||||
toolCallId: string,
|
||||
toolName: string,
|
||||
input: Record<string, unknown>,
|
||||
executorOverride?: ExternalToolExecutor,
|
||||
): Promise<ToolExecutionResult> {
|
||||
const executor = (globalThis as GlobalWithExternalTools)[
|
||||
EXTERNAL_EXECUTOR_KEY
|
||||
];
|
||||
const executor = executorOverride ?? getExternalToolExecutor();
|
||||
if (!executor) {
|
||||
return {
|
||||
toolReturn: `External tool executor not set for tool: ${toolName}`,
|
||||
@@ -518,6 +577,40 @@ export function getClientToolsFromRegistry(): ClientTool[] {
|
||||
return [...builtInTools, ...externalTools];
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture a turn-scoped tool snapshot and corresponding client_tools payload.
|
||||
* The returned context id can be used later to execute tool calls against this
|
||||
* exact snapshot even if the global registry changes between dispatch and execute.
|
||||
*/
|
||||
export function captureToolExecutionContext(): CapturedToolExecutionContext {
|
||||
const snapshot: ToolExecutionContextSnapshot = {
|
||||
toolRegistry: new Map(toolRegistry),
|
||||
externalTools: new Map(getExternalToolsRegistry()),
|
||||
externalExecutor: getExternalToolExecutor(),
|
||||
};
|
||||
const contextId = saveExecutionContext(snapshot);
|
||||
|
||||
const builtInTools = Array.from(snapshot.toolRegistry.entries()).map(
|
||||
([name, tool]) => ({
|
||||
name: getServerToolName(name),
|
||||
description: tool.schema.description,
|
||||
parameters: tool.schema.input_schema,
|
||||
}),
|
||||
);
|
||||
const externalTools = Array.from(snapshot.externalTools.values()).map(
|
||||
(tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
}),
|
||||
);
|
||||
|
||||
return {
|
||||
contextId,
|
||||
clientTools: [...builtInTools, ...externalTools],
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get permissions for a specific tool.
|
||||
* @param toolName - The name of the tool
|
||||
@@ -1030,29 +1123,46 @@ export async function executeTool(
|
||||
signal?: AbortSignal;
|
||||
toolCallId?: string;
|
||||
onOutput?: (chunk: string, stream: "stdout" | "stderr") => void;
|
||||
toolContextId?: string;
|
||||
},
|
||||
): Promise<ToolExecutionResult> {
|
||||
const context = options?.toolContextId
|
||||
? getExecutionContextById(options.toolContextId)
|
||||
: undefined;
|
||||
if (options?.toolContextId && !context) {
|
||||
return {
|
||||
toolReturn: `Tool execution context not found: ${options.toolContextId}`,
|
||||
status: "error",
|
||||
};
|
||||
}
|
||||
const activeRegistry = context?.toolRegistry ?? toolRegistry;
|
||||
const activeExternalTools =
|
||||
context?.externalTools ?? getExternalToolsRegistry();
|
||||
const activeExternalExecutor =
|
||||
context?.externalExecutor ?? getExternalToolExecutor();
|
||||
|
||||
// Check if this is an external tool (SDK-executed)
|
||||
if (isExternalTool(name)) {
|
||||
if (activeExternalTools.has(name)) {
|
||||
return executeExternalTool(
|
||||
options?.toolCallId ?? `ext-${Date.now()}`,
|
||||
name,
|
||||
args as Record<string, unknown>,
|
||||
activeExternalExecutor,
|
||||
);
|
||||
}
|
||||
|
||||
const internalName = resolveInternalToolName(name);
|
||||
const internalName = resolveInternalToolName(name, activeRegistry);
|
||||
if (!internalName) {
|
||||
return {
|
||||
toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(toolRegistry.keys()).join(", ")}`,
|
||||
toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(activeRegistry.keys()).join(", ")}`,
|
||||
status: "error",
|
||||
};
|
||||
}
|
||||
|
||||
const tool = toolRegistry.get(internalName);
|
||||
const tool = activeRegistry.get(internalName);
|
||||
if (!tool) {
|
||||
return {
|
||||
toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(toolRegistry.keys()).join(", ")}`,
|
||||
toolReturn: `Tool not found: ${name}. Available tools: ${Array.from(activeRegistry.keys()).join(", ")}`,
|
||||
status: "error",
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user