From 52f2cc9924f512b94dcbcd34dba1d72acf5819c1 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 5 Mar 2026 22:29:08 -0800 Subject: [PATCH] fix(listen): preserve interrupt error status through next-turn persistence (#1294) --- src/agent/approval-result-normalization.ts | 123 ++++++ src/agent/message.ts | 66 +++- .../approval-result-normalization.test.ts | 103 +++++ .../websocket/listen-client-protocol.test.ts | 159 ++++++++ .../websocket/listen-interrupt-queue.test.ts | 160 +++++++- src/websocket/listen-client.ts | 365 ++++++++++++++++-- 6 files changed, 918 insertions(+), 58 deletions(-) create mode 100644 src/agent/approval-result-normalization.ts create mode 100644 src/tests/agent/approval-result-normalization.test.ts diff --git a/src/agent/approval-result-normalization.ts b/src/agent/approval-result-normalization.ts new file mode 100644 index 0000000..52a612a --- /dev/null +++ b/src/agent/approval-result-normalization.ts @@ -0,0 +1,123 @@ +import type { MessageCreate } from "@letta-ai/letta-client/resources/agents/agents"; +import type { ApprovalCreate } from "@letta-ai/letta-client/resources/agents/messages"; +import { INTERRUPTED_BY_USER } from "../constants"; +import type { ApprovalResult } from "./approval-execution"; + +type OutgoingMessage = MessageCreate | ApprovalCreate; + +export type ApprovalNormalizationOptions = { + /** + * Structured interrupt provenance: tool_call_ids known to have been interrupted. + * When provided, these IDs are forced to persist as status=error. + */ + interruptedToolCallIds?: Iterable; + /** + * Temporary fallback guard for legacy drift where tool_return text is the only + * interrupt signal. Keep false by default for strict structured behavior. + */ + allowInterruptTextFallback?: boolean; +}; + +function normalizeToolReturnText(value: unknown): string { + if (typeof value === "string") return value; + + if (Array.isArray(value)) { + const text = value + .filter( + (part): part is { type: "text"; text: string } => + !!part && + typeof part === "object" && + "type" in part && + (part as { type?: unknown }).type === "text" && + "text" in part && + typeof (part as { text?: unknown }).text === "string", + ) + .map((part) => part.text) + .join("\n") + .trim(); + return text; + } + + if (value === null || value === undefined) return ""; + + try { + return JSON.stringify(value); + } catch { + return String(value); + } +} + +export function normalizeApprovalResultsForPersistence( + approvals: ApprovalResult[] | null | undefined, + options: ApprovalNormalizationOptions = {}, +): ApprovalResult[] { + if (!approvals || approvals.length === 0) return approvals ?? []; + + const interruptedSet = new Set(options.interruptedToolCallIds ?? []); + + return approvals.map((approval) => { + if ( + !approval || + typeof approval !== "object" || + !("type" in approval) || + approval.type !== "tool" + ) { + return approval; + } + + const toolCallId = + "tool_call_id" in approval && typeof approval.tool_call_id === "string" + ? approval.tool_call_id + : ""; + + const interruptedByStructuredId = + toolCallId.length > 0 && interruptedSet.has(toolCallId); + const interruptedByLegacyText = options.allowInterruptTextFallback + ? normalizeToolReturnText( + "tool_return" in approval ? approval.tool_return : "", + ) === INTERRUPTED_BY_USER + : false; + + if ( + (interruptedByStructuredId || interruptedByLegacyText) && + "status" in approval && + approval.status !== "error" + ) { + return { + ...approval, + status: "error" as const, + }; + } + + return approval; + }); +} + +export function normalizeOutgoingApprovalMessages( + messages: OutgoingMessage[], + options: ApprovalNormalizationOptions = {}, +): OutgoingMessage[] { + if (!messages || messages.length === 0) return messages; + + return messages.map((message) => { + if ( + !message || + typeof message !== "object" || + !("type" in message) || + message.type !== "approval" || + !("approvals" in message) + ) { + return message; + } + + const normalizedApprovals = normalizeApprovalResultsForPersistence( + message.approvals as ApprovalResult[], + options, + ); + + return { + ...message, + approvals: normalizedApprovals, + } as ApprovalCreate; + }); +} diff --git a/src/agent/message.ts b/src/agent/message.ts index e6e22b9..7f20a06 100644 --- a/src/agent/message.ts +++ b/src/agent/message.ts @@ -9,10 +9,15 @@ import type { LettaStreamingResponse, } from "@letta-ai/letta-client/resources/agents/messages"; import { + type ClientTool, captureToolExecutionContext, waitForToolsetReady, } from "../tools/manager"; import { isTimingsEnabled } from "../utils/timing"; +import { + type ApprovalNormalizationOptions, + normalizeOutgoingApprovalMessages, +} from "./approval-result-normalization"; import { getClient } from "./client"; const streamRequestStartTimes = new WeakMap(); @@ -43,6 +48,40 @@ export function getStreamRequestContext( return streamRequestContexts.get(stream as object); } +export type SendMessageStreamOptions = { + streamTokens?: boolean; + background?: boolean; + agentId?: string; // Required when conversationId is "default" + approvalNormalization?: ApprovalNormalizationOptions; +}; + +export function buildConversationMessagesCreateRequestBody( + conversationId: string, + messages: Array, + opts: SendMessageStreamOptions = { streamTokens: true, background: true }, + clientTools: ClientTool[], +) { + const isDefaultConversation = conversationId === "default"; + if (isDefaultConversation && !opts.agentId) { + throw new Error( + "agentId is required in opts when using default conversation", + ); + } + + return { + messages: normalizeOutgoingApprovalMessages( + messages, + opts.approvalNormalization, + ), + streaming: true, + stream_tokens: opts.streamTokens ?? true, + background: opts.background ?? true, + client_tools: clientTools, + include_compaction_messages: true, + ...(isDefaultConversation ? { agent_id: opts.agentId } : {}), + }; +} + /** * Send a message to a conversation and return a streaming response. * Uses the conversations API for all conversations. @@ -54,11 +93,7 @@ export function getStreamRequestContext( export async function sendMessageStream( conversationId: string, messages: Array, - opts: { - streamTokens?: boolean; - background?: boolean; - agentId?: string; // Required when conversationId is "default" - } = { streamTokens: true, background: true }, + opts: SendMessageStreamOptions = { streamTokens: true, background: true }, // Disable SDK retries by default - state management happens outside the stream, // so retries would violate idempotency and create race conditions requestOptions: { maxRetries?: number; signal?: AbortSignal } = { @@ -74,24 +109,13 @@ export async function sendMessageStream( await waitForToolsetReady(); const { clientTools, contextId } = captureToolExecutionContext(); - const isDefaultConversation = conversationId === "default"; - if (isDefaultConversation && !opts.agentId) { - throw new Error( - "agentId is required in opts when using default conversation", - ); - } - const resolvedConversationId = conversationId; - - const requestBody = { + const requestBody = buildConversationMessagesCreateRequestBody( + conversationId, messages, - streaming: true, - stream_tokens: opts.streamTokens ?? true, - background: opts.background ?? true, - client_tools: clientTools, - include_compaction_messages: true, - ...(isDefaultConversation ? { agent_id: opts.agentId } : {}), - }; + opts, + clientTools, + ); if (process.env.DEBUG) { console.log( diff --git a/src/tests/agent/approval-result-normalization.test.ts b/src/tests/agent/approval-result-normalization.test.ts new file mode 100644 index 0000000..1093ffe --- /dev/null +++ b/src/tests/agent/approval-result-normalization.test.ts @@ -0,0 +1,103 @@ +import { describe, expect, test } from "bun:test"; +import type { ApprovalCreate } from "@letta-ai/letta-client/resources/agents/messages"; +import type { ApprovalResult } from "../../agent/approval-execution"; +import { + normalizeApprovalResultsForPersistence, + normalizeOutgoingApprovalMessages, +} from "../../agent/approval-result-normalization"; +import { INTERRUPTED_BY_USER } from "../../constants"; + +describe("normalizeApprovalResultsForPersistence", () => { + test("forces status=error for structured interrupted tool_call_ids", () => { + const approvals: ApprovalResult[] = [ + { + type: "tool", + tool_call_id: "call-1", + tool_return: "some return", + status: "success", + } as ApprovalResult, + ]; + + const normalized = normalizeApprovalResultsForPersistence(approvals, { + interruptedToolCallIds: ["call-1"], + }); + + expect(normalized[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-1", + status: "error", + }); + }); + + test("does not modify non-interrupted tool results", () => { + const approvals: ApprovalResult[] = [ + { + type: "tool", + tool_call_id: "call-2", + tool_return: "ok", + status: "success", + } as ApprovalResult, + ]; + + const normalized = normalizeApprovalResultsForPersistence(approvals, { + interruptedToolCallIds: ["other-id"], + }); + + expect(normalized[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-2", + status: "success", + }); + }); + + test("supports legacy fallback on interrupt text when explicitly enabled", () => { + const approvals: ApprovalResult[] = [ + { + type: "tool", + tool_call_id: "call-3", + tool_return: [{ type: "text", text: INTERRUPTED_BY_USER }], + status: "success", + } as ApprovalResult, + ]; + + const normalized = normalizeApprovalResultsForPersistence(approvals, { + allowInterruptTextFallback: true, + }); + + expect(normalized[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-3", + status: "error", + }); + }); +}); + +describe("normalizeOutgoingApprovalMessages", () => { + test("normalizes approvals and preserves non-approval messages", () => { + const approvalMessage: ApprovalCreate = { + type: "approval", + approvals: [ + { + type: "tool", + tool_call_id: "call-7", + tool_return: "foo", + status: "success", + } as ApprovalResult, + ], + }; + + const messages = normalizeOutgoingApprovalMessages( + [{ role: "user", content: "hello" }, approvalMessage], + { interruptedToolCallIds: ["call-7"] }, + ); + + expect(messages[0]).toMatchObject({ role: "user", content: "hello" }); + const normalizedApproval = messages[1] as ApprovalCreate; + const approvals = normalizedApproval.approvals ?? []; + expect(approvals[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-7", + status: "error", + }); + }); +}); diff --git a/src/tests/websocket/listen-client-protocol.test.ts b/src/tests/websocket/listen-client-protocol.test.ts index 2677373..462f5cd 100644 --- a/src/tests/websocket/listen-client-protocol.test.ts +++ b/src/tests/websocket/listen-client-protocol.test.ts @@ -1,5 +1,8 @@ import { describe, expect, test } from "bun:test"; +import type { ApprovalCreate } from "@letta-ai/letta-client/resources/agents/messages"; import WebSocket from "ws"; +import { buildConversationMessagesCreateRequestBody } from "../../agent/message"; +import { INTERRUPTED_BY_USER } from "../../constants"; import type { ControlRequest, ControlResponseBody } from "../../types/protocol"; import { __listenClientTestUtils, @@ -640,3 +643,159 @@ describe("listen-client post-stop approval recovery policy", () => { expect(shouldRecover).toBe(false); }); }); + +describe("listen-client interrupt persistence normalization", () => { + test("forces interrupted in-flight tool results to status=error when cancelRequested", () => { + const runtime = __listenClientTestUtils.createRuntime(); + runtime.cancelRequested = true; + + const normalized = + __listenClientTestUtils.normalizeExecutionResultsForInterruptParity( + runtime, + [ + { + type: "tool", + tool_call_id: "tool-1", + tool_return: "Interrupted by user", + status: "success", + }, + ], + ["tool-1"], + ); + + expect(normalized).toEqual([ + { + type: "tool", + tool_call_id: "tool-1", + tool_return: "Interrupted by user", + status: "error", + }, + ]); + }); + + test("leaves tool status unchanged when not in cancel flow", () => { + const runtime = __listenClientTestUtils.createRuntime(); + runtime.cancelRequested = false; + + const normalized = + __listenClientTestUtils.normalizeExecutionResultsForInterruptParity( + runtime, + [ + { + type: "tool", + tool_call_id: "tool-1", + tool_return: "Interrupted by user", + status: "success", + }, + ], + ["tool-1"], + ); + + expect(normalized).toEqual([ + { + type: "tool", + tool_call_id: "tool-1", + tool_return: "Interrupted by user", + status: "success", + }, + ]); + }); +}); + +describe("listen-client interrupt persistence request body", () => { + test("post-interrupt next-turn payload keeps interrupted tool returns as status=error", () => { + const runtime = __listenClientTestUtils.createRuntime(); + const consumedAgentId = "agent-1"; + const consumedConversationId = "default"; + + __listenClientTestUtils.populateInterruptQueue(runtime, { + lastExecutionResults: null, + lastExecutingToolCallIds: ["call-running-1"], + lastNeedsUserInputToolCallIds: [], + agentId: consumedAgentId, + conversationId: consumedConversationId, + }); + + const consumed = __listenClientTestUtils.consumeInterruptQueue( + runtime, + consumedAgentId, + consumedConversationId, + ); + + expect(consumed).not.toBeNull(); + if (!consumed) { + throw new Error("Expected queued interrupt approvals to be consumed"); + } + + const requestBody = buildConversationMessagesCreateRequestBody( + consumedConversationId, + [ + consumed.approvalMessage, + { + type: "message", + role: "user", + content: "next user message after interrupt", + }, + ], + { + agentId: consumedAgentId, + streamTokens: true, + background: true, + approvalNormalization: { + interruptedToolCallIds: consumed.interruptedToolCallIds, + }, + }, + [], + ); + + const approvalMessage = requestBody.messages[0] as ApprovalCreate; + expect(approvalMessage.type).toBe("approval"); + expect(approvalMessage.approvals?.[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-running-1", + tool_return: INTERRUPTED_BY_USER, + status: "error", + }); + }); +}); + +describe("listen-client tool_return wire normalization", () => { + test("normalizes legacy top-level tool return fields to canonical tool_returns[]", () => { + const normalized = __listenClientTestUtils.normalizeToolReturnWireMessage({ + message_type: "tool_return_message", + id: "message-1", + run_id: "run-1", + tool_call_id: "call-1", + status: "error", + tool_return: [{ type: "text", text: "Interrupted by user" }], + }); + + expect(normalized).toEqual({ + message_type: "tool_return_message", + id: "message-1", + run_id: "run-1", + tool_returns: [ + { + tool_call_id: "call-1", + status: "error", + tool_return: "Interrupted by user", + }, + ], + }); + expect(normalized).not.toHaveProperty("tool_call_id"); + expect(normalized).not.toHaveProperty("status"); + expect(normalized).not.toHaveProperty("tool_return"); + }); + + test("returns null for tool_return_message when no canonical status is available", () => { + const normalized = __listenClientTestUtils.normalizeToolReturnWireMessage({ + message_type: "tool_return_message", + id: "message-2", + run_id: "run-2", + tool_call_id: "call-2", + tool_return: "maybe done", + }); + + expect(normalized).toBeNull(); + }); +}); diff --git a/src/tests/websocket/listen-interrupt-queue.test.ts b/src/tests/websocket/listen-interrupt-queue.test.ts index cffc612..4ea1220 100644 --- a/src/tests/websocket/listen-interrupt-queue.test.ts +++ b/src/tests/websocket/listen-interrupt-queue.test.ts @@ -62,12 +62,14 @@ describe("ListenerRuntime interrupt queue fields", () => { const runtime = createRuntime(); expect(runtime.pendingInterruptedResults).toBeNull(); expect(runtime.pendingInterruptedContext).toBeNull(); + expect(runtime.pendingInterruptedToolCallIds).toBeNull(); + expect(runtime.activeExecutingToolCallIds).toEqual([]); expect(runtime.continuationEpoch).toBe(0); }); }); describe("stopRuntime teardown", () => { - test("clears pendingInterruptedResults, context, and batch map", () => { + test("clears pendingInterruptedResults, context, ids, and batch map", () => { const runtime = createRuntime(); runtime.socket = new MockSocket(WebSocket.OPEN) as unknown as WebSocket; @@ -84,12 +86,16 @@ describe("stopRuntime teardown", () => { conversationId: "conv-1", continuationEpoch: 0, }; + runtime.pendingInterruptedToolCallIds = ["call-1"]; + runtime.activeExecutingToolCallIds = ["call-1"]; runtime.pendingApprovalBatchByToolCallId.set("call-1", "batch-1"); stopRuntime(runtime, true); expect(runtime.pendingInterruptedResults).toBeNull(); expect(runtime.pendingInterruptedContext).toBeNull(); + expect(runtime.pendingInterruptedToolCallIds).toBeNull(); + expect(runtime.activeExecutingToolCallIds).toEqual([]); expect(runtime.pendingApprovalBatchByToolCallId.size).toBe(0); }); @@ -178,6 +184,29 @@ describe("extractInterruptToolReturns", () => { ]); }); + test("converts multimodal tool_return content into displayable text", () => { + const results: ApprovalResult[] = [ + { + type: "tool", + tool_call_id: "call-multimodal", + status: "error", + tool_return: [ + { type: "text", text: "Interrupted by user" }, + { type: "image", image_url: "https://example.com/image.png" }, + ], + } as ApprovalResult, + ]; + + const mapped = extractInterruptToolReturns(results); + expect(mapped).toEqual([ + { + tool_call_id: "call-multimodal", + status: "error", + tool_return: "Interrupted by user", + }, + ]); + }); + test("emitInterruptToolReturnMessage emits deterministic per-tool terminal messages", () => { const runtime = createRuntime(); const socket = new MockSocket(WebSocket.OPEN) as unknown as WebSocket; @@ -208,16 +237,26 @@ describe("extractInterruptToolReturns", () => { expect(toolReturnFrames).toHaveLength(2); expect(toolReturnFrames[0]).toMatchObject({ run_id: "run-1", - tool_call_id: "call-a", - status: "success", - tool_returns: [{ tool_call_id: "call-a", status: "success" }], + tool_returns: [ + { tool_call_id: "call-a", status: "success", tool_return: "704" }, + ], }); expect(toolReturnFrames[1]).toMatchObject({ run_id: "run-1", - tool_call_id: "call-b", - status: "error", - tool_returns: [{ tool_call_id: "call-b", status: "error" }], + tool_returns: [ + { + tool_call_id: "call-b", + status: "error", + tool_return: "User interrupted the stream", + }, + ], }); + expect(toolReturnFrames[0]).not.toHaveProperty("tool_call_id"); + expect(toolReturnFrames[0]).not.toHaveProperty("status"); + expect(toolReturnFrames[0]).not.toHaveProperty("tool_return"); + expect(toolReturnFrames[1]).not.toHaveProperty("tool_call_id"); + expect(toolReturnFrames[1]).not.toHaveProperty("status"); + expect(toolReturnFrames[1]).not.toHaveProperty("tool_return"); }); }); @@ -305,13 +344,14 @@ describe("Path A: cancel during tool execution → next turn consumes actual res // Cancel fires: populateInterruptQueue (Path A — has execution results) const populated = populateInterruptQueue(runtime, { lastExecutionResults: executionResults, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: ["call-1", "call-2"], agentId, conversationId, }); expect(populated).toBe(true); - expect(runtime.pendingInterruptedResults).toBe(executionResults); + expect(runtime.pendingInterruptedResults).toEqual(executionResults); expect(runtime.pendingInterruptedContext).toMatchObject({ agentId, conversationId, @@ -322,9 +362,10 @@ describe("Path A: cancel during tool execution → next turn consumes actual res const consumed = consumeInterruptQueue(runtime, agentId, conversationId); expect(consumed).not.toBeNull(); - expect(consumed?.type).toBe("approval"); - expect(consumed?.approvals).toBe(executionResults); - expect(consumed?.approvals).toHaveLength(2); + expect(consumed?.approvalMessage.type).toBe("approval"); + expect(consumed?.approvalMessage.approvals).toEqual(executionResults); + expect(consumed?.approvalMessage.approvals).toHaveLength(2); + expect(consumed?.interruptedToolCallIds).toEqual([]); // Queue is atomically cleared after consumption expect(runtime.pendingInterruptedResults).toBeNull(); @@ -342,6 +383,7 @@ describe("Path A: cancel during tool execution → next turn consumes actual res const populated = populateInterruptQueue(runtime, { lastExecutionResults: executionResults, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: ["call-1"], agentId: "agent-1", conversationId: "conv-1", @@ -353,9 +395,87 @@ describe("Path A: cancel during tool execution → next turn consumes actual res approve: true, // Path A preserves actual approval state }); }); + + test("normalizes interrupted tool results to error via structured tool_call_id", () => { + const runtime = createRuntime(); + const executionResults: ApprovalResult[] = [ + { + type: "tool", + tool_call_id: "call-1", + status: "success", + tool_return: "result text does not matter when ID is interrupted", + } as unknown as ApprovalResult, + ]; + + const populated = populateInterruptQueue(runtime, { + lastExecutionResults: executionResults, + lastExecutingToolCallIds: ["call-1"], + lastNeedsUserInputToolCallIds: [], + agentId: "agent-1", + conversationId: "conv-1", + }); + + expect(populated).toBe(true); + expect(runtime.pendingInterruptedResults?.[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-1", + status: "error", + }); + expect(runtime.pendingInterruptedToolCallIds).toEqual(["call-1"]); + }); + + test("keeps legacy text fallback for interrupted tool return normalization", () => { + const runtime = createRuntime(); + const executionResults: ApprovalResult[] = [ + { + type: "tool", + tool_call_id: "call-legacy", + status: "success", + tool_return: [{ type: "text", text: "Interrupted by user" }], + } as unknown as ApprovalResult, + ]; + + const populated = populateInterruptQueue(runtime, { + lastExecutionResults: executionResults, + lastExecutingToolCallIds: [], + lastNeedsUserInputToolCallIds: [], + agentId: "agent-1", + conversationId: "conv-1", + }); + + expect(populated).toBe(true); + expect(runtime.pendingInterruptedResults?.[0]).toMatchObject({ + type: "tool", + tool_call_id: "call-legacy", + status: "error", + }); + }); }); describe("Path B: cancel during approval wait → next turn consumes synthesized denials", () => { + test("prefers synthesized tool-error results when execution was already in-flight", () => { + const runtime = createRuntime(); + + const populated = populateInterruptQueue(runtime, { + lastExecutionResults: null, + lastExecutingToolCallIds: ["call-running-1"], + lastNeedsUserInputToolCallIds: ["call-running-1"], + agentId: "agent-1", + conversationId: "conv-1", + }); + + expect(populated).toBe(true); + expect(runtime.pendingInterruptedResults).toEqual([ + { + type: "tool", + tool_call_id: "call-running-1", + tool_return: "Interrupted by user", + status: "error", + }, + ]); + expect(runtime.pendingInterruptedToolCallIds).toEqual(["call-running-1"]); + }); + test("full sequence: populate from batch map IDs → consume synthesized denials", () => { const runtime = createRuntime(); const agentId = "agent-abc"; @@ -371,6 +491,7 @@ describe("Path B: cancel during approval wait → next turn consumes synthesized // Cancel fires during approval wait: no execution results const populated = populateInterruptQueue(runtime, { lastExecutionResults: null, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId, conversationId, @@ -397,7 +518,7 @@ describe("Path B: cancel during approval wait → next turn consumes synthesized // Next user message: consume const consumed = consumeInterruptQueue(runtime, agentId, conversationId); expect(consumed).not.toBeNull(); - expect(consumed?.approvals).toHaveLength(2); + expect(consumed?.approvalMessage.approvals).toHaveLength(2); // Queue cleared expect(runtime.pendingInterruptedResults).toBeNull(); @@ -409,6 +530,7 @@ describe("Path B: cancel during approval wait → next turn consumes synthesized // No batch map entries, but we have the snapshot IDs const populated = populateInterruptQueue(runtime, { lastExecutionResults: null, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: ["call-a", "call-b"], agentId: "agent-1", conversationId: "conv-1", @@ -427,6 +549,7 @@ describe("Path B: cancel during approval wait → next turn consumes synthesized const populated = populateInterruptQueue(runtime, { lastExecutionResults: null, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -453,6 +576,7 @@ describe("post-cancel next turn: queue consumed exactly once (no error loop)", ( reason: "cancelled", }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId, conversationId: convId, @@ -476,6 +600,7 @@ describe("post-cancel next turn: queue consumed exactly once (no error loop)", ( lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId, conversationId: convId, @@ -496,6 +621,7 @@ describe("idempotency: first cancel populates, second is no-op", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-first", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -511,6 +637,7 @@ describe("idempotency: first cancel populates, second is no-op", () => { reason: "x", }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -530,6 +657,7 @@ describe("idempotency: first cancel populates, second is no-op", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -543,6 +671,7 @@ describe("idempotency: first cancel populates, second is no-op", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-2", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -564,6 +693,7 @@ describe("epoch guard: stale context discarded on consume", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -587,6 +717,7 @@ describe("epoch guard: stale context discarded on consume", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-old", conversationId: "conv-1", @@ -605,6 +736,7 @@ describe("epoch guard: stale context discarded on consume", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-old", @@ -623,6 +755,7 @@ describe("stale Path-B IDs: clearing after successful send prevents re-denial", // Also batch map should be cleared by clearPendingApprovalBatchIds const populated = populateInterruptQueue(runtime, { lastExecutionResults: null, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], // cleared after send agentId: "agent-1", conversationId: "conv-1", @@ -640,6 +773,7 @@ describe("stale Path-B IDs: clearing after successful send prevents re-denial", const populated = populateInterruptQueue(runtime, { lastExecutionResults: null, + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], // cleared from previous send agentId: "agent-1", conversationId: "conv-1", @@ -716,6 +850,7 @@ describe("consume clears pendingApprovalBatchByToolCallId", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-1", conversationId: "conv-1", @@ -734,6 +869,7 @@ describe("consume clears pendingApprovalBatchByToolCallId", () => { lastExecutionResults: [ { type: "approval", tool_call_id: "call-1", approve: true }, ], + lastExecutingToolCallIds: [], lastNeedsUserInputToolCallIds: [], agentId: "agent-old", conversationId: "conv-old", diff --git a/src/websocket/listen-client.ts b/src/websocket/listen-client.ts index 774e079..e9ba01a 100644 --- a/src/websocket/listen-client.ts +++ b/src/websocket/listen-client.ts @@ -18,6 +18,7 @@ import { executeApprovalBatch, } from "../agent/approval-execution"; import { fetchRunErrorDetail } from "../agent/approval-recovery"; +import { normalizeApprovalResultsForPersistence } from "../agent/approval-result-normalization"; import { getResumeData } from "../agent/check-approval"; import { getClient } from "../agent/client"; import { getStreamToolContextId, sendMessageStream } from "../agent/message"; @@ -34,6 +35,7 @@ import { createBuffers } from "../cli/helpers/accumulator"; import { classifyApprovals } from "../cli/helpers/approvalClassification"; import { generatePlanFilePath } from "../cli/helpers/planName"; import { drainStreamWithResume } from "../cli/helpers/stream"; +import { INTERRUPTED_BY_USER } from "../constants"; import { computeDiffPreviews } from "../helpers/diffPreview"; import { permissionMode } from "../permissions/mode"; import { type QueueItem, QueueRuntime } from "../queue/queueRuntime"; @@ -306,6 +308,16 @@ type ListenerRuntime = { } | null; /** Monotonic epoch for queued continuation validity checks. */ continuationEpoch: number; + /** + * Tool call ids currently executing in the active approval loop turn. + * Used for eager cancel-time interrupt capture parity with App/headless. + */ + activeExecutingToolCallIds: string[]; + /** + * Structured interrupted tool_call_ids carried with queued interrupt approvals. + * Threaded into the next send for persistence normalization. + */ + pendingInterruptedToolCallIds: string[] | null; }; type ApprovalSlot = @@ -382,6 +394,8 @@ function createRuntime(): ListenerRuntime { pendingInterruptedResults: null, pendingInterruptedContext: null, continuationEpoch: 0, + activeExecutingToolCallIds: [], + pendingInterruptedToolCallIds: null, coalescedSkipQueueItemIds: new Set(), pendingTurns: 0, // queueRuntime assigned below — needs runtime ref in callbacks @@ -559,6 +573,8 @@ function stopRuntime( // Clear interrupted queue on true teardown to prevent cross-session leakage. runtime.pendingInterruptedResults = null; runtime.pendingInterruptedContext = null; + runtime.pendingInterruptedToolCallIds = null; + runtime.activeExecutingToolCallIds = []; runtime.continuationEpoch++; if (!runtime.socket) { @@ -907,6 +923,7 @@ function shouldAttemptPostStopApprovalRecovery(params: { interface InterruptPopulateInput { lastExecutionResults: ApprovalResult[] | null; + lastExecutingToolCallIds: string[]; lastNeedsUserInputToolCallIds: string[]; agentId: string; conversationId: string; @@ -920,10 +937,48 @@ interface InterruptToolReturn { stderr?: string[]; } +function asToolReturnStatus(value: unknown): "success" | "error" | null { + if (value === "success" || value === "error") { + return value; + } + return null; +} + function normalizeToolReturnValue(value: unknown): string { if (typeof value === "string") { return value; } + if (Array.isArray(value)) { + const textParts = value + .filter( + ( + part, + ): part is { + type: string; + text: string; + } => + !!part && + typeof part === "object" && + "type" in part && + part.type === "text" && + "text" in part && + typeof part.text === "string", + ) + .map((part) => part.text); + if (textParts.length > 0) { + return textParts.join("\n"); + } + } + if ( + value && + typeof value === "object" && + "type" in value && + value.type === "text" && + "text" in value && + typeof value.text === "string" + ) { + return value.text; + } if (value === null || value === undefined) { return ""; } @@ -934,6 +989,130 @@ function normalizeToolReturnValue(value: unknown): string { } } +function normalizeInterruptedApprovalsForQueue( + approvals: ApprovalResult[] | null, + interruptedToolCallIds: string[], +): ApprovalResult[] | null { + if (!approvals || approvals.length === 0) { + return approvals; + } + + return normalizeApprovalResultsForPersistence(approvals, { + interruptedToolCallIds, + // Temporary fallback guard while all producers migrate to structured IDs. + allowInterruptTextFallback: true, + }); +} + +function normalizeExecutionResultsForInterruptParity( + runtime: ListenerRuntime, + executionResults: ApprovalResult[], + executingToolCallIds: string[], +): ApprovalResult[] { + if (!runtime.cancelRequested || executionResults.length === 0) { + return executionResults; + } + + return normalizeApprovalResultsForPersistence(executionResults, { + interruptedToolCallIds: executingToolCallIds, + }); +} + +function extractCanonicalToolReturnsFromWire( + payload: Record, +): InterruptToolReturn[] { + const fromArray: InterruptToolReturn[] = []; + const toolReturnsValue = payload.tool_returns; + if (Array.isArray(toolReturnsValue)) { + for (const raw of toolReturnsValue) { + if (!raw || typeof raw !== "object") { + continue; + } + const rec = raw as Record; + const toolCallId = + typeof rec.tool_call_id === "string" ? rec.tool_call_id : null; + const status = asToolReturnStatus(rec.status); + if (!toolCallId || !status) { + continue; + } + const stdout = Array.isArray(rec.stdout) + ? rec.stdout.filter( + (entry): entry is string => typeof entry === "string", + ) + : undefined; + const stderr = Array.isArray(rec.stderr) + ? rec.stderr.filter( + (entry): entry is string => typeof entry === "string", + ) + : undefined; + fromArray.push({ + tool_call_id: toolCallId, + status, + tool_return: normalizeToolReturnValue(rec.tool_return), + ...(stdout ? { stdout } : {}), + ...(stderr ? { stderr } : {}), + }); + } + } + if (fromArray.length > 0) { + return fromArray; + } + + const topLevelToolCallId = + typeof payload.tool_call_id === "string" ? payload.tool_call_id : null; + const topLevelStatus = asToolReturnStatus(payload.status); + if (!topLevelToolCallId || !topLevelStatus) { + return []; + } + const stdout = Array.isArray(payload.stdout) + ? payload.stdout.filter( + (entry): entry is string => typeof entry === "string", + ) + : undefined; + const stderr = Array.isArray(payload.stderr) + ? payload.stderr.filter( + (entry): entry is string => typeof entry === "string", + ) + : undefined; + return [ + { + tool_call_id: topLevelToolCallId, + status: topLevelStatus, + tool_return: normalizeToolReturnValue(payload.tool_return), + ...(stdout ? { stdout } : {}), + ...(stderr ? { stderr } : {}), + }, + ]; +} + +function normalizeToolReturnWireMessage( + chunk: Record, +): Record | null { + if (chunk.message_type !== "tool_return_message") { + return chunk; + } + + const canonicalToolReturns = extractCanonicalToolReturnsFromWire(chunk); + if (canonicalToolReturns.length === 0) { + return null; + } + + const { + tool_call_id: _toolCallId, + status: _status, + tool_return: _toolReturn, + stdout: _stdout, + stderr: _stderr, + ...rest + } = chunk; + + return { + ...rest, + message_type: "tool_return_message", + tool_returns: canonicalToolReturns, + }; +} + function extractInterruptToolReturns( approvals: ApprovalResult[] | null, ): InterruptToolReturn[] { @@ -1030,12 +1209,15 @@ function emitInterruptToolReturnMessage( id: `message-${crypto.randomUUID()}`, date: new Date().toISOString(), run_id: resolvedRunId, - tool_call_id: toolReturn.tool_call_id, - tool_return: toolReturn.tool_return, - status: toolReturn.status, - ...(toolReturn.stdout ? { stdout: toolReturn.stdout } : {}), - ...(toolReturn.stderr ? { stderr: toolReturn.stderr } : {}), - tool_returns: [toolReturn], + tool_returns: [ + { + tool_call_id: toolReturn.tool_call_id, + status: toolReturn.status, + tool_return: toolReturn.tool_return, + ...(toolReturn.stdout ? { stdout: toolReturn.stdout } : {}), + ...(toolReturn.stderr ? { stderr: toolReturn.stderr } : {}), + }, + ], session_id: runtime.sessionId, uuid: `${uuidPrefix}-${crypto.randomUUID()}`, } as unknown as MessageWire); @@ -1092,12 +1274,38 @@ function populateInterruptQueue( if (input.lastExecutionResults && input.lastExecutionResults.length > 0) { // Path A: execution happened before cancel — queue actual results - runtime.pendingInterruptedResults = input.lastExecutionResults; + // Guard parity: interrupted tool returns must persist as status=error. + runtime.pendingInterruptedResults = normalizeInterruptedApprovalsForQueue( + input.lastExecutionResults, + input.lastExecutingToolCallIds, + ); runtime.pendingInterruptedContext = { agentId: input.agentId, conversationId: input.conversationId, continuationEpoch: runtime.continuationEpoch, }; + runtime.pendingInterruptedToolCallIds = [...input.lastExecutingToolCallIds]; + return true; + } + + // Path A.5: execution was in-flight (approved tools started) but no + // terminal results were captured before cancel. Match App/headless parity by + // queuing explicit tool errors, not synthetic approval denials. + if (input.lastExecutingToolCallIds.length > 0) { + runtime.pendingInterruptedResults = input.lastExecutingToolCallIds.map( + (toolCallId) => ({ + type: "tool" as const, + tool_call_id: toolCallId, + tool_return: INTERRUPTED_BY_USER, + status: "error" as const, + }), + ); + runtime.pendingInterruptedContext = { + agentId: input.agentId, + conversationId: input.conversationId, + continuationEpoch: runtime.continuationEpoch, + }; + runtime.pendingInterruptedToolCallIds = [...input.lastExecutingToolCallIds]; return true; } @@ -1120,6 +1328,7 @@ function populateInterruptQueue( conversationId: input.conversationId, continuationEpoch: runtime.continuationEpoch, }; + runtime.pendingInterruptedToolCallIds = null; return true; } @@ -1144,7 +1353,10 @@ function consumeInterruptQueue( runtime: ListenerRuntime, agentId: string, conversationId: string, -): { type: "approval"; approvals: ApprovalResult[] } | null { +): { + approvalMessage: { type: "approval"; approvals: ApprovalResult[] }; + interruptedToolCallIds: string[]; +} | null { if ( !runtime.pendingInterruptedResults || runtime.pendingInterruptedResults.length === 0 @@ -1153,7 +1365,10 @@ function consumeInterruptQueue( } const ctx = runtime.pendingInterruptedContext; - let result: { type: "approval"; approvals: ApprovalResult[] } | null = null; + let result: { + approvalMessage: { type: "approval"; approvals: ApprovalResult[] }; + interruptedToolCallIds: string[]; + } | null = null; if ( ctx && @@ -1162,8 +1377,13 @@ function consumeInterruptQueue( ctx.continuationEpoch === runtime.continuationEpoch ) { result = { - type: "approval", - approvals: runtime.pendingInterruptedResults, + approvalMessage: { + type: "approval", + approvals: runtime.pendingInterruptedResults, + }, + interruptedToolCallIds: runtime.pendingInterruptedToolCallIds + ? [...runtime.pendingInterruptedToolCallIds] + : [], }; } @@ -1171,6 +1391,7 @@ function consumeInterruptQueue( // Stale results for wrong context are discarded, not retried. runtime.pendingInterruptedResults = null; runtime.pendingInterruptedContext = null; + runtime.pendingInterruptedToolCallIds = null; runtime.pendingApprovalBatchByToolCallId.clear(); return result; @@ -1963,6 +2184,30 @@ async function connectWithRetry( } runtime.cancelRequested = true; + // Eager interrupt capture parity with App/headless: + // if tool execution is currently in-flight, queue explicit interrupted + // tool results immediately at cancel time (before async catch paths). + if ( + runtime.activeExecutingToolCallIds.length > 0 && + (!runtime.pendingInterruptedResults || + runtime.pendingInterruptedResults.length === 0) + ) { + runtime.pendingInterruptedResults = + runtime.activeExecutingToolCallIds.map((toolCallId) => ({ + type: "tool", + tool_call_id: toolCallId, + tool_return: INTERRUPTED_BY_USER, + status: "error", + })); + runtime.pendingInterruptedContext = { + agentId: runtime.activeAgentId || "", + conversationId: runtime.activeConversationId || "default", + continuationEpoch: runtime.continuationEpoch, + }; + runtime.pendingInterruptedToolCallIds = [ + ...runtime.activeExecutingToolCallIds, + ]; + } if ( runtime.activeAbortController && !runtime.activeAbortController.signal.aborted @@ -2288,6 +2533,7 @@ async function handleIncomingMessage( // Track last approval-loop state for cancel-time queueing (Phase 1.2). // Hoisted before try so the cancel catch block can access them. let lastExecutionResults: ApprovalResult[] | null = null; + let lastExecutingToolCallIds: string[] = []; let lastNeedsUserInputToolCallIds: string[] = []; runtime.isProcessing = true; @@ -2297,6 +2543,7 @@ async function handleIncomingMessage( runtime.activeConversationId = conversationId; runtime.activeRunId = null; runtime.activeRunStartedAt = new Date().toISOString(); + runtime.activeExecutingToolCallIds = []; try { // Latch capability: once seen, always use blocking path (strict check to avoid truthy strings) @@ -2322,6 +2569,8 @@ async function handleIncomingMessage( let messagesToSend: Array = []; let turnToolContextId: string | null = null; + let queuedInterruptedToolCallIds: string[] = []; + let shouldClearSubmittedApprovalTracking = false; // Prepend queued interrupted results from a prior cancelled turn. const consumed = consumeInterruptQueue( @@ -2330,7 +2579,8 @@ async function handleIncomingMessage( conversationId, ); if (consumed) { - messagesToSend.push(consumed); + messagesToSend.push(consumed.approvalMessage); + queuedInterruptedToolCallIds = consumed.interruptedToolCallIds; } messagesToSend.push(...msg.messages); @@ -2362,12 +2612,29 @@ async function handleIncomingMessage( approvalMessage, resumeData.pendingApprovals, ); + lastExecutingToolCallIds = decisions + .filter( + ( + decision, + ): decision is Extract => + decision.type === "approve", + ) + .map((decision) => decision.approval.toolCallId); + runtime.activeExecutingToolCallIds = [...lastExecutingToolCallIds]; + const decisionResults = decisions.length > 0 ? await executeApprovalBatch(decisions, undefined, { toolContextId: turnToolContextId ?? undefined, + abortSignal: runtime.activeAbortController.signal, }) : []; + const persistedDecisionResults = + normalizeExecutionResultsForInterruptParity( + runtime, + decisionResults, + lastExecutingToolCallIds, + ); const rebuiltApprovals: ApprovalResult[] = []; let decisionResultIndex = 0; @@ -2378,7 +2645,7 @@ async function handleIncomingMessage( continue; } - const next = decisionResults[decisionResultIndex]; + const next = persistedDecisionResults[decisionResultIndex]; if (next) { rebuiltApprovals.push(next); decisionResultIndex++; @@ -2393,6 +2660,8 @@ async function handleIncomingMessage( }); } + lastExecutionResults = rebuiltApprovals; + shouldClearSubmittedApprovalTracking = true; messagesToSend = [ { type: "approval", @@ -2411,15 +2680,33 @@ async function handleIncomingMessage( } let currentInput = messagesToSend; + const sendOptions: Parameters[2] = { + agentId, + streamTokens: true, + background: true, + ...(queuedInterruptedToolCallIds.length > 0 + ? { + approvalNormalization: { + interruptedToolCallIds: queuedInterruptedToolCallIds, + }, + } + : {}), + }; let stream = await sendMessageStreamWithRetry( conversationId, currentInput, - { agentId, streamTokens: true, background: true }, + sendOptions, socket, runtime, runtime.activeAbortController.signal, ); + if (shouldClearSubmittedApprovalTracking) { + lastExecutionResults = null; + lastExecutingToolCallIds = []; + lastNeedsUserInputToolCallIds = []; + runtime.activeExecutingToolCallIds = []; + } turnToolContextId = getStreamToolContextId( stream as Stream, @@ -2477,12 +2764,18 @@ async function handleIncomingMessage( otid?: string; id?: string; }; - emitToWS(socket, { - ...chunk, - type: "message", - session_id: runtime.sessionId, - uuid: chunkWithIds.otid || chunkWithIds.id || crypto.randomUUID(), - } as MessageWire); + const normalizedChunk = normalizeToolReturnWireMessage( + chunk as unknown as Record, + ); + if (normalizedChunk) { + emitToWS(socket, { + ...normalizedChunk, + type: "message", + session_id: runtime.sessionId, + uuid: + chunkWithIds.otid || chunkWithIds.id || crypto.randomUUID(), + } as unknown as MessageWire); + } } return undefined; @@ -2603,7 +2896,7 @@ async function handleIncomingMessage( stream = await sendMessageStreamWithRetry( conversationId, currentInput, - { agentId, streamTokens: true, background: true }, + sendOptions, socket, runtime, runtime.activeAbortController.signal, @@ -2875,6 +3168,16 @@ async function handleIncomingMessage( } } + // Snapshot executing tool_call_ids before execution starts so cancel can + // preserve tool-error parity even if execution aborts mid-await. + lastExecutingToolCallIds = decisions + .filter( + (decision): decision is Extract => + decision.type === "approve", + ) + .map((decision) => decision.approval.toolCallId); + runtime.activeExecutingToolCallIds = [...lastExecutingToolCallIds]; + // Execute approved/denied tools const executionResults = await executeApprovalBatch( decisions, @@ -2884,13 +3187,19 @@ async function handleIncomingMessage( abortSignal: runtime.activeAbortController.signal, }, ); - lastExecutionResults = executionResults; + const persistedExecutionResults = + normalizeExecutionResultsForInterruptParity( + runtime, + executionResults, + lastExecutingToolCallIds, + ); + lastExecutionResults = persistedExecutionResults; // WS-first parity: publish tool-return terminal outcomes immediately on // normal approval execution, before continuation stream send. emitInterruptToolReturnMessage( socket, runtime, - executionResults, + persistedExecutionResults, runtime.activeRunId || runId || msgRunIds[msgRunIds.length - 1] || @@ -2906,13 +3215,13 @@ async function handleIncomingMessage( currentInput = [ { type: "approval", - approvals: executionResults, + approvals: persistedExecutionResults, }, ]; stream = await sendMessageStreamWithRetry( conversationId, currentInput, - { agentId, streamTokens: true, background: true }, + sendOptions, socket, runtime, runtime.activeAbortController.signal, @@ -2922,7 +3231,9 @@ async function handleIncomingMessage( // cancel during the subsequent stream drain won't queue already-sent // results (Path A) or re-deny already-resolved tool calls (Path B). lastExecutionResults = null; + lastExecutingToolCallIds = []; lastNeedsUserInputToolCallIds = []; + runtime.activeExecutingToolCallIds = []; turnToolContextId = getStreamToolContextId( stream as Stream, @@ -2933,6 +3244,7 @@ async function handleIncomingMessage( // Queue interrupted tool-call resolutions for the next message turn. populateInterruptQueue(runtime, { lastExecutionResults, + lastExecutingToolCallIds, lastNeedsUserInputToolCallIds, agentId: agentId || "", conversationId, @@ -3042,6 +3354,7 @@ async function handleIncomingMessage( } finally { runtime.activeAbortController = null; runtime.cancelRequested = false; + runtime.activeExecutingToolCallIds = []; } } @@ -3078,5 +3391,7 @@ export const __listenClientTestUtils = { extractInterruptToolReturns, emitInterruptToolReturnMessage, getInterruptApprovalsForEmission, + normalizeToolReturnWireMessage, + normalizeExecutionResultsForInterruptParity, shouldAttemptPostStopApprovalRecovery, };