fix(listen): enforce strict queue/run correlation in recovery

This commit is contained in:
cpacker
2026-03-01 00:06:37 -08:00
parent ac4621d91d
commit b4910cd410
7 changed files with 342 additions and 38 deletions

View File

@@ -32,6 +32,7 @@ import { drainStreamWithResume } from "../cli/helpers/stream";
import { computeDiffPreviews } from "../helpers/diffPreview";
import { permissionMode } from "../permissions/mode";
import { type QueueItem, QueueRuntime } from "../queue/queueRuntime";
import { mergeQueuedTurnInput } from "../queue/turnQueueRuntime";
import { settingsManager } from "../settings-manager";
import { isInteractiveApprovalTool } from "../tools/interactivePolicy";
import { loadTools } from "../tools/manager";
@@ -100,7 +101,9 @@ interface IncomingMessage {
type: "message";
agentId?: string;
conversationId?: string;
messages: Array<MessageCreate | ApprovalCreate>;
messages: Array<
(MessageCreate & { client_message_id?: string }) | ApprovalCreate
>;
/** Cloud sets this when it supports can_use_tool / control_response protocol. */
supportsControlResponse?: boolean;
}
@@ -116,6 +119,7 @@ interface ResultMessage {
/**
 * Wire shape of the "run_started" client message.
 *
 * handleIncomingMessage sends one of these through sendClientMessage with
 * `runId` and `batch_id` populated.
 */
interface RunStartedMessage {
type: "run_started";
/** Identifier of the run that was started. */
runId: string;
/**
 * Batch identifier supplied to handleIncomingMessage for the turn that
 * started this run: either the id of the dequeued queue batch or a generated
 * `batch-direct-<uuid>` fallback when no batch id was supplied.
 */
batch_id: string;
// NOTE(review): event_seq and session_id are not set at the visible send
// site; presumably they are stamped elsewhere before transmission — confirm.
event_seq?: number;
session_id?: string;
}
@@ -194,6 +198,7 @@ interface StateResponseMessage {
pending_turns: number;
items: Array<{
id: string;
client_message_id: string;
kind: string;
source: string;
content: unknown;
@@ -256,6 +261,11 @@ type ListenerRuntime = {
cancelRequested: boolean;
/** Queue lifecycle tracking — parallel tracking layer, does not affect message processing. */
queueRuntime: QueueRuntime;
/**
* Queue item IDs that were coalesced into an earlier dequeued batch.
* Their already-scheduled promise-chain callbacks should no-op.
*/
coalescedSkipQueueItemIds: Set<string>;
/** Count of turns currently queued or in-flight in the promise chain. Incremented
* synchronously on message arrival (before .then()) to avoid async scheduling races. */
pendingTurns: number;
@@ -263,6 +273,11 @@ type ListenerRuntime = {
onWsEvent?: StartListenerOptions["onWsEvent"];
/** Prevent duplicate concurrent pending-approval recovery passes. */
isRecoveringApprovals: boolean;
/**
* Correlates pending approval tool_call_ids to the originating dequeued batch.
* Used to preserve run attachment continuity across approval recovery.
*/
pendingApprovalBatchByToolCallId: Map<string, string>;
};
type ApprovalSlot =
@@ -335,6 +350,8 @@ function createRuntime(): ListenerRuntime {
activeAbortController: null,
cancelRequested: false,
isRecoveringApprovals: false,
pendingApprovalBatchByToolCallId: new Map<string, string>(),
coalescedSkipQueueItemIds: new Set<string>(),
pendingTurns: 0,
// queueRuntime assigned below — needs runtime ref in callbacks
queueRuntime: null as unknown as QueueRuntime,
@@ -348,6 +365,7 @@ function createRuntime(): ListenerRuntime {
type: "queue_item_enqueued",
id: item.id,
item_id: item.id,
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
source: item.source,
kind: item.kind,
content,
@@ -430,6 +448,51 @@ function clearActiveRunState(runtime: ListenerRuntime): void {
runtime.activeAbortController = null;
}
/**
 * Associates each pending approval with the batch it originated from.
 *
 * For every approval that carries a truthy `toolCallId`, stores
 * `toolCallId -> batchId` in `runtime.pendingApprovalBatchByToolCallId`,
 * replacing any mapping already present for that tool call. Approvals whose
 * `toolCallId` is empty or missing are skipped.
 *
 * @param runtime Listener runtime that owns the correlation map.
 * @param pendingApprovals Approvals awaiting a decision.
 * @param batchId Batch identifier to record for each approval.
 */
function rememberPendingApprovalBatchIds(
  runtime: ListenerRuntime,
  pendingApprovals: Array<{ toolCallId: string }>,
  batchId: string,
): void {
  pendingApprovals.forEach(({ toolCallId }) => {
    // Nothing to correlate without a tool call id.
    if (!toolCallId) {
      return;
    }
    runtime.pendingApprovalBatchByToolCallId.set(toolCallId, batchId);
  });
}
/**
 * Determines the single originating batch shared by a set of pending approvals.
 *
 * Fails closed: the result is `null` when the list is empty, when any approval
 * has no recorded batch in `runtime.pendingApprovalBatchByToolCallId`, or when
 * the approvals map to more than one distinct batch.
 *
 * @param runtime Listener runtime that owns the correlation map.
 * @param pendingApprovals Approvals whose origin batch should be resolved.
 * @returns The common batch id, or `null` if no unambiguous batch exists.
 */
function resolvePendingApprovalBatchId(
  runtime: ListenerRuntime,
  pendingApprovals: Array<{ toolCallId: string }>,
): string | null {
  let sharedBatchId: string | null = null;

  for (const { toolCallId } of pendingApprovals) {
    const originBatchId =
      runtime.pendingApprovalBatchByToolCallId.get(toolCallId);

    // Every pending approval must have an originating batch mapping.
    if (!originBatchId) {
      return null;
    }

    // A second, different batch makes the correlation ambiguous.
    if (sharedBatchId !== null && sharedBatchId !== originBatchId) {
      return null;
    }

    sharedBatchId = originBatchId;
  }

  // Still null here only when there were no approvals at all.
  return sharedBatchId;
}
/**
 * Drops the batch correlation entries for the given approvals.
 *
 * Removes each approval's `toolCallId` from
 * `runtime.pendingApprovalBatchByToolCallId`; ids that are not present are
 * ignored.
 *
 * @param runtime Listener runtime that owns the correlation map.
 * @param approvals Approvals whose mappings should be forgotten.
 */
function clearPendingApprovalBatchIds(
  runtime: ListenerRuntime,
  approvals: Array<{ toolCallId: string }>,
): void {
  approvals.forEach(({ toolCallId }) => {
    runtime.pendingApprovalBatchByToolCallId.delete(toolCallId);
  });
}
function stopRuntime(
runtime: ListenerRuntime,
suppressCallbacks: boolean,
@@ -444,6 +507,7 @@ function stopRuntime(
}
clearRuntimeTimers(runtime);
rejectPendingApprovalResolvers(runtime, "Listener runtime stopped");
runtime.pendingApprovalBatchByToolCallId.clear();
if (!runtime.socket) {
return;
@@ -537,12 +601,45 @@ function getQueueItemContent(item: QueueItem): unknown {
return item.kind === "message" ? item.content : item.text;
}
/**
 * Builds the content for a single turn from a dequeued batch of queue items.
 *
 * Each item is translated, in order, into the input shape passed to
 * `mergeQueuedTurnInput`: "message" items become `kind: "user"` entries
 * carrying their content, and "task_notification" items become
 * `kind: "task_notification"` entries carrying their text. Items of any other
 * kind are left out. User content is passed through unchanged
 * (`normalizeUserContent` is the identity).
 *
 * @param items Queue items that were dequeued together as one batch.
 * @returns Whatever `mergeQueuedTurnInput` produces for the translated inputs:
 *   merged message content, or `null`.
 */
function mergeDequeuedBatchContent(
  items: QueueItem[],
): MessageCreate["content"] | null {
  type QueuedInput =
    | { kind: "user"; content: MessageCreate["content"] }
    | { kind: "task_notification"; text: string };

  const turnInputs: QueuedInput[] = [];

  for (const queued of items) {
    switch (queued.kind) {
      case "message":
        turnInputs.push({ kind: "user", content: queued.content });
        break;
      case "task_notification":
        turnInputs.push({ kind: "task_notification", text: queued.text });
        break;
      default:
        // Other item kinds contribute nothing to the merged turn input.
        break;
    }
  }

  return mergeQueuedTurnInput(turnInputs, {
    normalizeUserContent: (userContent) => userContent,
  });
}
function buildStateResponse(
runtime: ListenerRuntime,
stateSeq: number,
): StateResponseMessage {
const queueItems = runtime.queueRuntime.items.map((item) => ({
id: item.id,
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
kind: item.kind,
source: item.source,
content: getQueueItemContent(item),
@@ -1011,6 +1108,23 @@ async function recoverPendingApprovals(
return;
}
const recoveryBatchId = resolvePendingApprovalBatchId(
runtime,
pendingApprovals,
);
if (!recoveryBatchId) {
emitToWS(socket, {
type: "error",
message:
"Unable to recover pending approvals without originating batch correlation",
stop_reason: "error",
session_id: runtime.sessionId,
uuid: `error-${crypto.randomUUID()}`,
});
runtime.lastStopReason = "requires_approval";
return;
}
type Decision =
| {
type: "approve";
@@ -1154,6 +1268,10 @@ async function recoverPendingApprovals(
}
const executionResults = await executeApprovalBatch(decisions);
clearPendingApprovalBatchIds(
runtime,
decisions.map((decision) => decision.approval),
);
await handleIncomingMessage(
{
@@ -1170,6 +1288,9 @@ async function recoverPendingApprovals(
},
socket,
runtime,
undefined,
undefined,
recoveryBatchId,
);
} finally {
runtime.isRecoveringApprovals = false;
@@ -1474,12 +1595,19 @@ async function connectWithRetry(
const firstPayload = parsed.messages.at(0);
const isUserMessage =
firstPayload !== undefined && "content" in firstPayload;
let enqueuedQueueItemId: string | null = null;
if (isUserMessage) {
runtime.queueRuntime.enqueue({
const userPayload = firstPayload as MessageCreate & {
client_message_id?: string;
};
const enqueuedItem = runtime.queueRuntime.enqueue({
kind: "message",
source: "user",
content: (firstPayload as MessageCreate).content,
content: userPayload.content,
clientMessageId:
userPayload.client_message_id ?? `cm-submit-${crypto.randomUUID()}`,
} as Parameters<typeof runtime.queueRuntime.enqueue>[0]);
enqueuedQueueItemId = enqueuedItem?.id ?? null;
// Emit blocked on state transition when turns are already queued.
// pendingTurns is incremented synchronously (below) before .then(),
// so a second arrival always sees the correct count.
@@ -1497,9 +1625,50 @@ async function connectWithRetry(
return;
}
// Signal dequeue for exactly this one turn (one message per chain cb)
if (isUserMessage) {
runtime.queueRuntime.consumeItems(1);
let messageForTurn = parsed;
let dequeuedBatchId: string | null = null;
if (isUserMessage && enqueuedQueueItemId) {
if (runtime.coalescedSkipQueueItemIds.has(enqueuedQueueItemId)) {
runtime.coalescedSkipQueueItemIds.delete(enqueuedQueueItemId);
runtime.pendingTurns--;
if (runtime.pendingTurns === 0) {
runtime.queueRuntime.resetBlockedState();
}
return;
}
const dequeuedBatch = runtime.queueRuntime.tryDequeue(null);
if (!dequeuedBatch) {
runtime.pendingTurns--;
if (runtime.pendingTurns === 0) {
runtime.queueRuntime.resetBlockedState();
}
return;
}
dequeuedBatchId = dequeuedBatch.batchId;
for (const item of dequeuedBatch.items) {
if (item.id !== enqueuedQueueItemId) {
runtime.coalescedSkipQueueItemIds.add(item.id);
}
}
const mergedContent = mergeDequeuedBatchContent(
dequeuedBatch.items,
);
if (mergedContent !== null) {
const firstMessage = parsed.messages.at(0);
if (firstMessage && "content" in firstMessage) {
const mergedFirstMessage = {
...firstMessage,
content: mergedContent,
};
messageForTurn = {
...parsed,
messages: [mergedFirstMessage, ...parsed.messages.slice(1)],
};
}
}
}
// onStatusChange("receiving") is inside try so that any throw
@@ -1507,11 +1676,12 @@ async function connectWithRetry(
try {
opts.onStatusChange?.("receiving", opts.connectionId);
await handleIncomingMessage(
parsed,
messageForTurn,
socket,
runtime,
opts.onStatusChange,
opts.connectionId,
dequeuedBatchId ?? `batch-direct-${crypto.randomUUID()}`,
);
opts.onStatusChange?.("idle", opts.connectionId);
} finally {
@@ -1544,6 +1714,7 @@ async function connectWithRetry(
// Single authoritative queue_cleared emission for all close paths
// (intentional and unintentional). Must fire before early returns.
runtime.coalescedSkipQueueItemIds.clear();
runtime.queueRuntime.clear("shutdown");
if (process.env.DEBUG) {
@@ -1613,6 +1784,7 @@ async function handleIncomingMessage(
connectionId: string,
) => void,
connectionId?: string,
dequeuedBatchId: string = `batch-direct-${crypto.randomUUID()}`,
): Promise<void> {
// Hoist identifiers and tracking state so they're available in catch for error-result
const agentId = msg.agentId;
@@ -1761,6 +1933,7 @@ async function handleIncomingMessage(
sendClientMessage(socket, {
type: "run_started",
runId: maybeRunId,
batch_id: dequeuedBatchId,
});
}
}
@@ -1916,6 +2089,10 @@ async function handleIncomingMessage(
break;
}
// Persist origin correlation for this approval wait so a later recovery
// can continue the same dequeued-turn run block.
rememberPendingApprovalBatchIds(runtime, approvals, dequeuedBatchId);
// Classify approvals (auto-allow, auto-deny, needs user input)
// Don't treat "ask" as deny - cloud UI can handle approvals
// Interactive tools (AskUserQuestion, EnterPlanMode, ExitPlanMode) always need user input
@@ -2078,6 +2255,10 @@ async function handleIncomingMessage(
abortSignal: runtime.activeAbortController.signal,
},
);
clearPendingApprovalBatchIds(
runtime,
decisions.map((decision) => decision.approval),
);
// Create fresh approval stream for next iteration
stream = await sendMessageStreamWithRetry(
@@ -2198,4 +2379,7 @@ export const __listenClientTestUtils = {
createRuntime,
stopRuntime,
emitToWS,
rememberPendingApprovalBatchIds,
resolvePendingApprovalBatchId,
clearPendingApprovalBatchIds,
};