fix(listen): enforce strict queue/run correlation in recovery

This commit is contained in:
cpacker
2026-03-01 00:06:37 -08:00
parent ac4621d91d
commit b4910cd410
7 changed files with 342 additions and 38 deletions

View File

@@ -2496,6 +2496,7 @@ async function runBidirectionalMode(
emitQueueEvent({
type: "queue_item_enqueued",
item_id: item.id,
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
source: item.source,
kind: item.kind,
queue_len: queueLen,

View File

@@ -14,6 +14,8 @@ export type { QueueBlockedReason, QueueClearedReason, QueueItemKind };
/**
 * NOTE(review): presumably the field set shared by every queue item kind —
 * confirm against the kind-specific item types declared alongside it.
 */
type QueueItemBase = {
/** Stable monotonic ID assigned on enqueue. */
id: string;
/**
 * Optional client-side message correlation ID from submit payloads.
 * When unset, queue event emitters fall back to `cm-${id}` on the wire.
 */
clientMessageId?: string;
/** Origin of the item, e.g. "user" or "task_notification". */
source: QueueItemSource;
/** NOTE(review): presumably the enqueue timestamp — confirm clock and units. */
enqueuedAt: number;
};

View File

@@ -25,6 +25,7 @@ describe("QueueItemEnqueuedEvent wire shape", () => {
const event: QueueItemEnqueuedEvent = {
type: "queue_item_enqueued",
item_id: "item-1",
client_message_id: "cm-item-1",
source: "user",
kind: "message",
queue_len: 1,
@@ -184,6 +185,7 @@ describe("QueueLifecycleEvent union", () => {
{
type: "queue_item_enqueued",
item_id: "i1",
client_message_id: "cm-i1",
source: "user",
kind: "message",
queue_len: 1,
@@ -237,6 +239,7 @@ describe("QueueLifecycleEvent union", () => {
const event: QueueLifecycleEvent = {
type: "queue_item_enqueued",
item_id: "i1",
client_message_id: "cm-i1",
source: "task_notification",
kind: "task_notification",
queue_len: 2,
@@ -258,6 +261,7 @@ describe("QueueLifecycleEvent union", () => {
const event: QueueLifecycleEvent = {
type: "queue_item_enqueued",
item_id: "i1",
client_message_id: "cm-i1",
source: "user",
kind: "message",
queue_len: 1,

View File

@@ -377,6 +377,79 @@ describe("listen-client capability-gated approval flow", () => {
});
});
describe("listen-client approval recovery batch correlation", () => {
  // Short local aliases for the correlation helpers under test.
  const {
    createRuntime,
    rememberPendingApprovalBatchIds: remember,
    resolvePendingApprovalBatchId: resolve,
    clearPendingApprovalBatchIds: clear,
  } = __listenClientTestUtils;

  /** Builds the minimal approval shape (`{ toolCallId }`) the helpers read. */
  const approvalsFor = (...toolCallIds: string[]) =>
    toolCallIds.map((toolCallId) => ({ toolCallId }));

  test("resolves the original batch id from pending tool call ids", () => {
    const runtime = createRuntime();
    remember(runtime, approvalsFor("tool-1", "tool-2"), "batch-123");

    const resolved = resolve(runtime, approvalsFor("tool-1", "tool-2"));

    expect(resolved).toBe("batch-123");
  });

  test("returns null when pending approvals map to multiple batches", () => {
    const runtime = createRuntime();
    remember(runtime, approvalsFor("tool-a"), "batch-a");
    remember(runtime, approvalsFor("tool-b"), "batch-b");

    const resolved = resolve(runtime, approvalsFor("tool-a", "tool-b"));

    expect(resolved).toBeNull();
  });

  test("returns null when one pending approval mapping is missing", () => {
    const runtime = createRuntime();
    remember(runtime, approvalsFor("tool-a"), "batch-a");

    const resolved = resolve(runtime, approvalsFor("tool-a", "tool-missing"));

    expect(resolved).toBeNull();
  });

  test("clears correlation after approvals are executed", () => {
    const runtime = createRuntime();
    remember(runtime, approvalsFor("tool-x"), "batch-x");
    clear(runtime, approvalsFor("tool-x"));

    const resolved = resolve(runtime, approvalsFor("tool-x"));

    expect(resolved).toBeNull();
  });
});
describe("listen-client emitToWS adapter", () => {
test("sends event when socket is OPEN", () => {
const socket = new MockSocket(WebSocket.OPEN);

View File

@@ -67,29 +67,49 @@ function simulateMessageArrival(
q: QueueRuntime,
pendingTurnsRef: { value: number },
payload: MessageCreate | ApprovalCreate,
): boolean {
): { isUserMessage: boolean; queueItemId?: string } {
const isUserMessage = "content" in payload;
let queueItemId: string | undefined;
if (isUserMessage) {
q.enqueue({
const enqueued = q.enqueue({
kind: "message",
source: "user",
content: (payload as MessageCreate).content,
} as Parameters<typeof q.enqueue>[0]);
queueItemId = enqueued?.id;
if (pendingTurnsRef.value > 0) {
q.tryDequeue("runtime_busy");
}
}
pendingTurnsRef.value++; // synchronous before .then()
return isUserMessage;
return { isUserMessage, queueItemId };
}
/** Mirrors the start of the .then() chain callback. */
function simulateTurnStart(
q: QueueRuntime,
_pendingTurnsRef: { value: number },
isUserMessage: boolean,
arrival: { isUserMessage: boolean; queueItemId?: string },
skipIds: Set<string>,
): void {
if (isUserMessage) q.consumeItems(1);
if (!arrival.isUserMessage || !arrival.queueItemId) {
return;
}
if (skipIds.has(arrival.queueItemId)) {
skipIds.delete(arrival.queueItemId);
return;
}
const batch = q.tryDequeue(null);
if (!batch) {
return;
}
for (const item of batch.items) {
if (item.id !== arrival.queueItemId) {
skipIds.add(item.id);
}
}
}
/** Mirrors the finally block. */
@@ -116,13 +136,14 @@ describe("single message — idle path", () => {
test("enqueued → dequeued, no blocked, real queue_len values", () => {
const { q, rec } = buildRuntime();
const turns = { value: 0 };
const skipIds = new Set<string>();
const isUser = simulateMessageArrival(q, turns, makeMessageCreate());
const firstArrival = simulateMessageArrival(q, turns, makeMessageCreate());
expect(rec.enqueued).toHaveLength(1);
expect(rec.enqueued.at(0)?.queueLen).toBe(1);
expect(rec.blocked).toHaveLength(0);
simulateTurnStart(q, turns, isUser);
simulateTurnStart(q, turns, firstArrival, skipIds);
expect(rec.dequeued).toHaveLength(1);
expect(rec.dequeued.at(0)?.mergedCount).toBe(1);
expect(rec.dequeued.at(0)?.queueLenAfter).toBe(0);
@@ -137,9 +158,10 @@ describe("two rapid messages — busy path", () => {
test("second arrival gets blocked(runtime_busy) due to sync pendingTurns", () => {
const { q, rec } = buildRuntime();
const turns = { value: 0 };
const skipIds = new Set<string>();
// First message arrives
const isUser1 = simulateMessageArrival(
const arrival1 = simulateMessageArrival(
q,
turns,
makeMessageCreate("first"),
@@ -148,7 +170,7 @@ describe("two rapid messages — busy path", () => {
expect(rec.blocked).toHaveLength(0); // was 0 at arrival
// Second message arrives BEFORE first turn's .then() runs
const isUser2 = simulateMessageArrival(
const arrival2 = simulateMessageArrival(
q,
turns,
makeMessageCreate("second"),
@@ -159,17 +181,16 @@ describe("two rapid messages — busy path", () => {
expect(rec.blocked.at(0)?.queueLen).toBe(2); // both enqueued
// First turn runs
simulateTurnStart(q, turns, isUser1);
simulateTurnStart(q, turns, arrival1, skipIds);
expect(rec.dequeued).toHaveLength(1);
expect(rec.dequeued.at(0)?.mergedCount).toBe(1);
expect(rec.dequeued.at(0)?.mergedCount).toBe(2);
expect(rec.dequeued.at(0)?.queueLenAfter).toBe(0);
simulateTurnEnd(q, turns);
expect(turns.value).toBe(1); // second still pending
// Second turn runs
simulateTurnStart(q, turns, isUser2);
expect(rec.dequeued).toHaveLength(2);
expect(rec.dequeued.at(1)?.mergedCount).toBe(1);
expect(rec.dequeued.at(1)?.queueLenAfter).toBe(0);
// Second callback no-ops (item already consumed in coalesced batch).
simulateTurnStart(q, turns, arrival2, skipIds);
expect(rec.dequeued).toHaveLength(1);
simulateTurnEnd(q, turns);
expect(turns.value).toBe(0);
});
@@ -177,15 +198,21 @@ describe("two rapid messages — busy path", () => {
test("blocked fires only once for same reason; resets when fully drained", () => {
const { q, rec } = buildRuntime();
const turns = { value: 0 };
const skipIds = new Set<string>();
simulateMessageArrival(q, turns, makeMessageCreate("a"));
simulateMessageArrival(q, turns, makeMessageCreate("b")); // blocked
simulateMessageArrival(q, turns, makeMessageCreate("c")); // same reason — no extra blocked
const arrivalA = simulateMessageArrival(q, turns, makeMessageCreate("a"));
const arrivalB = simulateMessageArrival(q, turns, makeMessageCreate("b")); // blocked
const arrivalC = simulateMessageArrival(q, turns, makeMessageCreate("c")); // same reason — no extra blocked
expect(rec.blocked).toHaveLength(1);
// Drain all three
const queuedArrivals = [arrivalA, arrivalB, arrivalC];
for (let i = 0; i < 3; i++) {
simulateTurnStart(q, turns, true);
const queuedArrival = queuedArrivals[i];
if (!queuedArrival) {
continue;
}
simulateTurnStart(q, turns, queuedArrival, skipIds);
simulateTurnEnd(q, turns);
}
expect(turns.value).toBe(0);
@@ -203,12 +230,13 @@ describe("pendingTurns safety — always decremented", () => {
// (finally equivalent) always restores pendingTurns to 0.
const { q } = buildRuntime();
const turns = { value: 0 };
const skipIds = new Set<string>();
simulateMessageArrival(q, turns, makeMessageCreate("msg"));
const arrival = simulateMessageArrival(q, turns, makeMessageCreate("msg"));
expect(turns.value).toBe(1);
// Simulate: consumeItems fires, then an error before handleIncomingMessage
q.consumeItems(1);
simulateTurnStart(q, turns, arrival, skipIds);
// finally fires (error path)
simulateTurnEnd(q, turns);
expect(turns.value).toBe(0); // not leaked
@@ -220,14 +248,15 @@ describe("ApprovalCreate payloads", () => {
test("ApprovalCreate is not enqueued (no content field)", () => {
const { q, rec } = buildRuntime();
const turns = { value: 0 };
const skipIds = new Set<string>();
const isUser = simulateMessageArrival(q, turns, makeApprovalCreate());
expect(isUser).toBe(false);
const arrival = simulateMessageArrival(q, turns, makeApprovalCreate());
expect(arrival.isUserMessage).toBe(false);
expect(rec.enqueued).toHaveLength(0);
expect(turns.value).toBe(1); // pendingTurns still increments
// No consumeItems called in .then()
simulateTurnStart(q, turns, isUser);
simulateTurnStart(q, turns, arrival, skipIds);
expect(rec.dequeued).toHaveLength(0);
simulateTurnEnd(q, turns);
expect(turns.value).toBe(0);
@@ -257,18 +286,27 @@ describe("per-turn error — no queue_cleared", () => {
test("turn error only decrements pendingTurns; remaining turns still dequeue", () => {
const { q, rec } = buildRuntime();
const turns = { value: 0 };
const skipIds = new Set<string>();
simulateMessageArrival(q, turns, makeMessageCreate("first"));
simulateMessageArrival(q, turns, makeMessageCreate("second"));
const arrival1 = simulateMessageArrival(
q,
turns,
makeMessageCreate("first"),
);
const arrival2 = simulateMessageArrival(
q,
turns,
makeMessageCreate("second"),
);
// First turn: simulate error — finally still runs
simulateTurnStart(q, turns, true);
simulateTurnStart(q, turns, arrival1, skipIds);
simulateTurnEnd(q, turns); // error path still hits finally
expect(rec.cleared).toHaveLength(0); // no queue_cleared
// Second turn still runs
simulateTurnStart(q, turns, true);
expect(rec.dequeued).toHaveLength(2);
// Second callback no-ops; first turn already consumed coalesced batch.
simulateTurnStart(q, turns, arrival2, skipIds);
expect(rec.dequeued).toHaveLength(1);
simulateTurnEnd(q, turns);
expect(turns.value).toBe(0);
expect(rec.cleared).toHaveLength(0); // still no queue_cleared

View File

@@ -314,6 +314,8 @@ export interface QueueItemEnqueuedEvent extends MessageEnvelope {
id?: string;
/** @deprecated Use `id`. */
item_id: string;
/** Correlates this queue item back to the originating client submit payload. */
client_message_id: string;
source: QueueItemSource;
kind: QueueItemKind;
/** Full queue item content; renderers may truncate for display. */

View File

@@ -32,6 +32,7 @@ import { drainStreamWithResume } from "../cli/helpers/stream";
import { computeDiffPreviews } from "../helpers/diffPreview";
import { permissionMode } from "../permissions/mode";
import { type QueueItem, QueueRuntime } from "../queue/queueRuntime";
import { mergeQueuedTurnInput } from "../queue/turnQueueRuntime";
import { settingsManager } from "../settings-manager";
import { isInteractiveApprovalTool } from "../tools/interactivePolicy";
import { loadTools } from "../tools/manager";
@@ -100,7 +101,9 @@ interface IncomingMessage {
type: "message";
agentId?: string;
conversationId?: string;
messages: Array<MessageCreate | ApprovalCreate>;
messages: Array<
(MessageCreate & { client_message_id?: string }) | ApprovalCreate
>;
/** Cloud sets this when it supports can_use_tool / control_response protocol. */
supportsControlResponse?: boolean;
}
@@ -116,6 +119,7 @@ interface ResultMessage {
/** Client-bound message announcing that a run has started. */
interface RunStartedMessage {
type: "run_started";
runId: string;
/**
 * Batch this run is attributed to: the id of the dequeued queue batch, the
 * batch id resolved during pending-approval recovery, or a generated
 * `batch-direct-<uuid>` when no batch id was supplied.
 */
batch_id: string;
event_seq?: number;
session_id?: string;
}
@@ -194,6 +198,7 @@ interface StateResponseMessage {
pending_turns: number;
items: Array<{
id: string;
client_message_id: string;
kind: string;
source: string;
content: unknown;
@@ -256,6 +261,11 @@ type ListenerRuntime = {
cancelRequested: boolean;
/** Queue lifecycle tracking. Dequeued batches also drive turn processing: they supply the merged turn content and the batch id reported for the run. */
queueRuntime: QueueRuntime;
/**
* Queue item IDs that were coalesced into an earlier dequeued batch.
* Their already-scheduled promise-chain callbacks should no-op.
*/
coalescedSkipQueueItemIds: Set<string>;
/** Count of turns currently queued or in-flight in the promise chain. Incremented
* synchronously on message arrival (before .then()) to avoid async scheduling races. */
pendingTurns: number;
@@ -263,6 +273,11 @@ type ListenerRuntime = {
onWsEvent?: StartListenerOptions["onWsEvent"];
/** Prevent duplicate concurrent pending-approval recovery passes. */
isRecoveringApprovals: boolean;
/**
* Correlates pending approval tool_call_ids to the originating dequeued batch.
* Used to preserve run attachment continuity across approval recovery.
*/
pendingApprovalBatchByToolCallId: Map<string, string>;
};
type ApprovalSlot =
@@ -335,6 +350,8 @@ function createRuntime(): ListenerRuntime {
activeAbortController: null,
cancelRequested: false,
isRecoveringApprovals: false,
pendingApprovalBatchByToolCallId: new Map<string, string>(),
coalescedSkipQueueItemIds: new Set<string>(),
pendingTurns: 0,
// queueRuntime assigned below — needs runtime ref in callbacks
queueRuntime: null as unknown as QueueRuntime,
@@ -348,6 +365,7 @@ function createRuntime(): ListenerRuntime {
type: "queue_item_enqueued",
id: item.id,
item_id: item.id,
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
source: item.source,
kind: item.kind,
content,
@@ -430,6 +448,51 @@ function clearActiveRunState(runtime: ListenerRuntime): void {
runtime.activeAbortController = null;
}
/**
 * Records `batchId` as the originating batch for each pending approval.
 *
 * Approvals with an empty/falsy `toolCallId` are ignored. A tool call id that
 * is already tracked is overwritten with the new `batchId`.
 *
 * @param runtime - Listener runtime whose correlation map is updated.
 * @param pendingApprovals - Approvals awaiting a decision.
 * @param batchId - Batch id to associate with every listed tool call.
 */
function rememberPendingApprovalBatchIds(
  runtime: ListenerRuntime,
  pendingApprovals: Array<{ toolCallId: string }>,
  batchId: string,
): void {
  pendingApprovals
    .filter((candidate) => Boolean(candidate.toolCallId))
    .forEach((candidate) => {
      runtime.pendingApprovalBatchByToolCallId.set(
        candidate.toolCallId,
        batchId,
      );
    });
}
/**
 * Resolves the single originating batch id shared by all pending approvals.
 *
 * Fails closed: returns `null` when the list is empty, when any approval has
 * no recorded batch, or when the approvals span more than one batch.
 *
 * @param runtime - Listener runtime holding the tool-call → batch map.
 * @param pendingApprovals - Approvals whose origin should be resolved.
 * @returns The common batch id, or `null` if it cannot be determined.
 */
function resolvePendingApprovalBatchId(
  runtime: ListenerRuntime,
  pendingApprovals: Array<{ toolCallId: string }>,
): string | null {
  const distinctOrigins: string[] = [];

  for (const { toolCallId } of pendingApprovals) {
    const origin = runtime.pendingApprovalBatchByToolCallId.get(toolCallId);
    // Every pending approval must map back to a batch; otherwise give up.
    if (!origin) {
      return null;
    }
    if (!distinctOrigins.includes(origin)) {
      distinctOrigins.push(origin);
    }
  }

  const [onlyOrigin, ...otherOrigins] = distinctOrigins;
  if (onlyOrigin === undefined || otherOrigins.length > 0) {
    return null;
  }
  return onlyOrigin;
}
/**
 * Drops the batch correlation for each given approval's tool call id.
 * Tool call ids that are not tracked are ignored.
 *
 * @param runtime - Listener runtime whose correlation map is pruned.
 * @param approvals - Approvals whose correlation entries should be removed.
 */
function clearPendingApprovalBatchIds(
  runtime: ListenerRuntime,
  approvals: Array<{ toolCallId: string }>,
): void {
  approvals.forEach(({ toolCallId }) => {
    runtime.pendingApprovalBatchByToolCallId.delete(toolCallId);
  });
}
function stopRuntime(
runtime: ListenerRuntime,
suppressCallbacks: boolean,
@@ -444,6 +507,7 @@ function stopRuntime(
}
clearRuntimeTimers(runtime);
rejectPendingApprovalResolvers(runtime, "Listener runtime stopped");
runtime.pendingApprovalBatchByToolCallId.clear();
if (!runtime.socket) {
return;
@@ -537,12 +601,45 @@ function getQueueItemContent(item: QueueItem): unknown {
return item.kind === "message" ? item.content : item.text;
}
/**
 * Converts a dequeued batch into the input shape `mergeQueuedTurnInput`
 * accepts and returns its result.
 *
 * `message` items become `user` inputs carrying their content, and
 * `task_notification` items carry their text. Items of any other kind are
 * dropped. User content is passed through unchanged (identity normalizer).
 *
 * @param items - Queue items from one dequeued batch, in batch order.
 * @returns Whatever `mergeQueuedTurnInput` produces for the converted inputs.
 */
function mergeDequeuedBatchContent(
  items: QueueItem[],
): MessageCreate["content"] | null {
  type QueuedInput =
    | { kind: "user"; content: MessageCreate["content"] }
    | { kind: "task_notification"; text: string };

  const queuedInputs = items.flatMap((queued): QueuedInput[] => {
    switch (queued.kind) {
      case "message":
        return [{ kind: "user", content: queued.content }];
      case "task_notification":
        return [{ kind: "task_notification", text: queued.text }];
      default:
        // Kinds with no turn-input representation contribute nothing.
        return [];
    }
  });

  return mergeQueuedTurnInput(queuedInputs, {
    normalizeUserContent: (content) => content,
  });
}
function buildStateResponse(
runtime: ListenerRuntime,
stateSeq: number,
): StateResponseMessage {
const queueItems = runtime.queueRuntime.items.map((item) => ({
id: item.id,
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
kind: item.kind,
source: item.source,
content: getQueueItemContent(item),
@@ -1011,6 +1108,23 @@ async function recoverPendingApprovals(
return;
}
const recoveryBatchId = resolvePendingApprovalBatchId(
runtime,
pendingApprovals,
);
if (!recoveryBatchId) {
emitToWS(socket, {
type: "error",
message:
"Unable to recover pending approvals without originating batch correlation",
stop_reason: "error",
session_id: runtime.sessionId,
uuid: `error-${crypto.randomUUID()}`,
});
runtime.lastStopReason = "requires_approval";
return;
}
type Decision =
| {
type: "approve";
@@ -1154,6 +1268,10 @@ async function recoverPendingApprovals(
}
const executionResults = await executeApprovalBatch(decisions);
clearPendingApprovalBatchIds(
runtime,
decisions.map((decision) => decision.approval),
);
await handleIncomingMessage(
{
@@ -1170,6 +1288,9 @@ async function recoverPendingApprovals(
},
socket,
runtime,
undefined,
undefined,
recoveryBatchId,
);
} finally {
runtime.isRecoveringApprovals = false;
@@ -1474,12 +1595,19 @@ async function connectWithRetry(
const firstPayload = parsed.messages.at(0);
const isUserMessage =
firstPayload !== undefined && "content" in firstPayload;
let enqueuedQueueItemId: string | null = null;
if (isUserMessage) {
runtime.queueRuntime.enqueue({
const userPayload = firstPayload as MessageCreate & {
client_message_id?: string;
};
const enqueuedItem = runtime.queueRuntime.enqueue({
kind: "message",
source: "user",
content: (firstPayload as MessageCreate).content,
content: userPayload.content,
clientMessageId:
userPayload.client_message_id ?? `cm-submit-${crypto.randomUUID()}`,
} as Parameters<typeof runtime.queueRuntime.enqueue>[0]);
enqueuedQueueItemId = enqueuedItem?.id ?? null;
// Emit blocked on state transition when turns are already queued.
// pendingTurns is incremented synchronously (below) before .then(),
// so a second arrival always sees the correct count.
@@ -1497,9 +1625,50 @@ async function connectWithRetry(
return;
}
// Signal dequeue for exactly this one turn (one message per chain cb)
if (isUserMessage) {
runtime.queueRuntime.consumeItems(1);
let messageForTurn = parsed;
let dequeuedBatchId: string | null = null;
if (isUserMessage && enqueuedQueueItemId) {
if (runtime.coalescedSkipQueueItemIds.has(enqueuedQueueItemId)) {
runtime.coalescedSkipQueueItemIds.delete(enqueuedQueueItemId);
runtime.pendingTurns--;
if (runtime.pendingTurns === 0) {
runtime.queueRuntime.resetBlockedState();
}
return;
}
const dequeuedBatch = runtime.queueRuntime.tryDequeue(null);
if (!dequeuedBatch) {
runtime.pendingTurns--;
if (runtime.pendingTurns === 0) {
runtime.queueRuntime.resetBlockedState();
}
return;
}
dequeuedBatchId = dequeuedBatch.batchId;
for (const item of dequeuedBatch.items) {
if (item.id !== enqueuedQueueItemId) {
runtime.coalescedSkipQueueItemIds.add(item.id);
}
}
const mergedContent = mergeDequeuedBatchContent(
dequeuedBatch.items,
);
if (mergedContent !== null) {
const firstMessage = parsed.messages.at(0);
if (firstMessage && "content" in firstMessage) {
const mergedFirstMessage = {
...firstMessage,
content: mergedContent,
};
messageForTurn = {
...parsed,
messages: [mergedFirstMessage, ...parsed.messages.slice(1)],
};
}
}
}
// onStatusChange("receiving") is inside try so that any throw
@@ -1507,11 +1676,12 @@ async function connectWithRetry(
try {
opts.onStatusChange?.("receiving", opts.connectionId);
await handleIncomingMessage(
parsed,
messageForTurn,
socket,
runtime,
opts.onStatusChange,
opts.connectionId,
dequeuedBatchId ?? `batch-direct-${crypto.randomUUID()}`,
);
opts.onStatusChange?.("idle", opts.connectionId);
} finally {
@@ -1544,6 +1714,7 @@ async function connectWithRetry(
// Single authoritative queue_cleared emission for all close paths
// (intentional and unintentional). Must fire before early returns.
runtime.coalescedSkipQueueItemIds.clear();
runtime.queueRuntime.clear("shutdown");
if (process.env.DEBUG) {
@@ -1613,6 +1784,7 @@ async function handleIncomingMessage(
connectionId: string,
) => void,
connectionId?: string,
dequeuedBatchId: string = `batch-direct-${crypto.randomUUID()}`,
): Promise<void> {
// Hoist identifiers and tracking state so they're available in catch for error-result
const agentId = msg.agentId;
@@ -1761,6 +1933,7 @@ async function handleIncomingMessage(
sendClientMessage(socket, {
type: "run_started",
runId: maybeRunId,
batch_id: dequeuedBatchId,
});
}
}
@@ -1916,6 +2089,10 @@ async function handleIncomingMessage(
break;
}
// Persist origin correlation for this approval wait so a later recovery
// can continue the same dequeued-turn run block.
rememberPendingApprovalBatchIds(runtime, approvals, dequeuedBatchId);
// Classify approvals (auto-allow, auto-deny, needs user input)
// Don't treat "ask" as deny - cloud UI can handle approvals
// Interactive tools (AskUserQuestion, EnterPlanMode, ExitPlanMode) always need user input
@@ -2078,6 +2255,10 @@ async function handleIncomingMessage(
abortSignal: runtime.activeAbortController.signal,
},
);
clearPendingApprovalBatchIds(
runtime,
decisions.map((decision) => decision.approval),
);
// Create fresh approval stream for next iteration
stream = await sendMessageStreamWithRetry(
@@ -2198,4 +2379,7 @@ export const __listenClientTestUtils = {
createRuntime,
stopRuntime,
emitToWS,
rememberPendingApprovalBatchIds,
resolvePendingApprovalBatchId,
clearPendingApprovalBatchIds,
};