fix(listen): enforce strict queue/run correlation in recovery
This commit is contained in:
@@ -32,6 +32,7 @@ import { drainStreamWithResume } from "../cli/helpers/stream";
|
||||
import { computeDiffPreviews } from "../helpers/diffPreview";
|
||||
import { permissionMode } from "../permissions/mode";
|
||||
import { type QueueItem, QueueRuntime } from "../queue/queueRuntime";
|
||||
import { mergeQueuedTurnInput } from "../queue/turnQueueRuntime";
|
||||
import { settingsManager } from "../settings-manager";
|
||||
import { isInteractiveApprovalTool } from "../tools/interactivePolicy";
|
||||
import { loadTools } from "../tools/manager";
|
||||
@@ -100,7 +101,9 @@ interface IncomingMessage {
|
||||
type: "message";
|
||||
agentId?: string;
|
||||
conversationId?: string;
|
||||
messages: Array<MessageCreate | ApprovalCreate>;
|
||||
messages: Array<
|
||||
(MessageCreate & { client_message_id?: string }) | ApprovalCreate
|
||||
>;
|
||||
/** Cloud sets this when it supports can_use_tool / control_response protocol. */
|
||||
supportsControlResponse?: boolean;
|
||||
}
|
||||
@@ -116,6 +119,7 @@ interface ResultMessage {
|
||||
/**
 * Client message announcing that a run has started for the current turn.
 */
interface RunStartedMessage {
  type: "run_started";
  /** Identifier of the run that was started. */
  runId: string;
  /**
   * Batch the run is correlated with: the id of the dequeued queue batch, or a
   * generated `batch-direct-…` id when the turn did not come from a dequeue.
   */
  batch_id: string;
  event_seq?: number;
  session_id?: string;
}
|
||||
@@ -194,6 +198,7 @@ interface StateResponseMessage {
|
||||
pending_turns: number;
|
||||
items: Array<{
|
||||
id: string;
|
||||
client_message_id: string;
|
||||
kind: string;
|
||||
source: string;
|
||||
content: unknown;
|
||||
@@ -256,6 +261,11 @@ type ListenerRuntime = {
|
||||
cancelRequested: boolean;
|
||||
/** Queue lifecycle tracking — parallel tracking layer, does not affect message processing. */
|
||||
queueRuntime: QueueRuntime;
|
||||
/**
|
||||
* Queue item IDs that were coalesced into an earlier dequeued batch.
|
||||
* Their already-scheduled promise-chain callbacks should no-op.
|
||||
*/
|
||||
coalescedSkipQueueItemIds: Set<string>;
|
||||
/** Count of turns currently queued or in-flight in the promise chain. Incremented
|
||||
* synchronously on message arrival (before .then()) to avoid async scheduling races. */
|
||||
pendingTurns: number;
|
||||
@@ -263,6 +273,11 @@ type ListenerRuntime = {
|
||||
onWsEvent?: StartListenerOptions["onWsEvent"];
|
||||
/** Prevent duplicate concurrent pending-approval recovery passes. */
|
||||
isRecoveringApprovals: boolean;
|
||||
/**
|
||||
* Correlates pending approval tool_call_ids to the originating dequeued batch.
|
||||
* Used to preserve run attachment continuity across approval recovery.
|
||||
*/
|
||||
pendingApprovalBatchByToolCallId: Map<string, string>;
|
||||
};
|
||||
|
||||
type ApprovalSlot =
|
||||
@@ -335,6 +350,8 @@ function createRuntime(): ListenerRuntime {
|
||||
activeAbortController: null,
|
||||
cancelRequested: false,
|
||||
isRecoveringApprovals: false,
|
||||
pendingApprovalBatchByToolCallId: new Map<string, string>(),
|
||||
coalescedSkipQueueItemIds: new Set<string>(),
|
||||
pendingTurns: 0,
|
||||
// queueRuntime assigned below — needs runtime ref in callbacks
|
||||
queueRuntime: null as unknown as QueueRuntime,
|
||||
@@ -348,6 +365,7 @@ function createRuntime(): ListenerRuntime {
|
||||
type: "queue_item_enqueued",
|
||||
id: item.id,
|
||||
item_id: item.id,
|
||||
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
|
||||
source: item.source,
|
||||
kind: item.kind,
|
||||
content,
|
||||
@@ -430,6 +448,51 @@ function clearActiveRunState(runtime: ListenerRuntime): void {
|
||||
runtime.activeAbortController = null;
|
||||
}
|
||||
|
||||
/**
 * Records which batch each pending approval originated from.
 *
 * Every approval with a non-empty `toolCallId` is mapped to `batchId` in
 * `runtime.pendingApprovalBatchByToolCallId`, overwriting any earlier mapping
 * for the same tool call id. Approvals with an empty `toolCallId` are skipped
 * because they could never be looked up again.
 *
 * @param runtime Listener runtime holding the correlation map.
 * @param pendingApprovals Approvals awaiting a decision.
 * @param batchId Batch id to associate with each approval.
 */
function rememberPendingApprovalBatchIds(
  runtime: ListenerRuntime,
  pendingApprovals: Array<{ toolCallId: string }>,
  batchId: string,
): void {
  for (const approval of pendingApprovals) {
    if (approval.toolCallId) {
      runtime.pendingApprovalBatchByToolCallId.set(
        approval.toolCallId,
        batchId,
      );
    }
  }
}
|
||||
|
||||
/**
 * Resolves the single originating batch id shared by all pending approvals.
 *
 * Fails closed: returns `null` when the list is empty, when any approval has
 * no recorded batch mapping, or when the approvals map to more than one batch.
 *
 * @param runtime Listener runtime holding the correlation map.
 * @param pendingApprovals Approvals whose originating batch is being resolved.
 * @returns The shared batch id, or `null` if no unambiguous batch exists.
 */
function resolvePendingApprovalBatchId(
  runtime: ListenerRuntime,
  pendingApprovals: Array<{ toolCallId: string }>,
): string | null {
  let resolvedBatchId: string | null = null;
  for (const approval of pendingApprovals) {
    const batchId = runtime.pendingApprovalBatchByToolCallId.get(
      approval.toolCallId,
    );
    // Fail closed: every pending approval must have an originating batch mapping.
    if (!batchId) {
      return null;
    }
    // Fail closed: approvals spanning several batches cannot be attached to one run.
    if (resolvedBatchId !== null && resolvedBatchId !== batchId) {
      return null;
    }
    resolvedBatchId = batchId;
  }
  return resolvedBatchId;
}
|
||||
|
||||
/**
 * Removes the batch mapping for each of the given approvals.
 *
 * Approvals with no recorded mapping are ignored.
 *
 * @param runtime Listener runtime holding the correlation map.
 * @param approvals Approvals whose tool call ids should be removed from the map.
 */
function clearPendingApprovalBatchIds(
  runtime: ListenerRuntime,
  approvals: Array<{ toolCallId: string }>,
): void {
  for (const approval of approvals) {
    runtime.pendingApprovalBatchByToolCallId.delete(approval.toolCallId);
  }
}
|
||||
|
||||
function stopRuntime(
|
||||
runtime: ListenerRuntime,
|
||||
suppressCallbacks: boolean,
|
||||
@@ -444,6 +507,7 @@ function stopRuntime(
|
||||
}
|
||||
clearRuntimeTimers(runtime);
|
||||
rejectPendingApprovalResolvers(runtime, "Listener runtime stopped");
|
||||
runtime.pendingApprovalBatchByToolCallId.clear();
|
||||
|
||||
if (!runtime.socket) {
|
||||
return;
|
||||
@@ -537,12 +601,45 @@ function getQueueItemContent(item: QueueItem): unknown {
|
||||
return item.kind === "message" ? item.content : item.text;
|
||||
}
|
||||
|
||||
/**
 * Merges the items of a dequeued batch into a single turn's message content.
 *
 * `message` items are forwarded as `user` inputs and `task_notification`
 * items as `task_notification` inputs, in their original order; items of any
 * other kind are dropped. The merge itself is delegated to
 * `mergeQueuedTurnInput`, with user content passed through unchanged.
 *
 * @param items Queue items that were dequeued together as one batch.
 * @returns The merged content from `mergeQueuedTurnInput`, or `null`
 *   (presumably when there is nothing to merge — confirm against that helper).
 */
function mergeDequeuedBatchContent(
  items: QueueItem[],
): MessageCreate["content"] | null {
  type QueuedInput =
    | { kind: "user"; content: MessageCreate["content"] }
    | { kind: "task_notification"; text: string };

  const queuedInputs: QueuedInput[] = [];

  for (const item of items) {
    if (item.kind === "message") {
      queuedInputs.push({ kind: "user", content: item.content });
      continue;
    }
    if (item.kind === "task_notification") {
      queuedInputs.push({ kind: "task_notification", text: item.text });
    }
  }

  return mergeQueuedTurnInput(queuedInputs, {
    // Content already has the shape the turn expects; pass it through untouched.
    normalizeUserContent: (content) => content,
  });
}
|
||||
|
||||
function buildStateResponse(
|
||||
runtime: ListenerRuntime,
|
||||
stateSeq: number,
|
||||
): StateResponseMessage {
|
||||
const queueItems = runtime.queueRuntime.items.map((item) => ({
|
||||
id: item.id,
|
||||
client_message_id: item.clientMessageId ?? `cm-${item.id}`,
|
||||
kind: item.kind,
|
||||
source: item.source,
|
||||
content: getQueueItemContent(item),
|
||||
@@ -1011,6 +1108,23 @@ async function recoverPendingApprovals(
|
||||
return;
|
||||
}
|
||||
|
||||
const recoveryBatchId = resolvePendingApprovalBatchId(
|
||||
runtime,
|
||||
pendingApprovals,
|
||||
);
|
||||
if (!recoveryBatchId) {
|
||||
emitToWS(socket, {
|
||||
type: "error",
|
||||
message:
|
||||
"Unable to recover pending approvals without originating batch correlation",
|
||||
stop_reason: "error",
|
||||
session_id: runtime.sessionId,
|
||||
uuid: `error-${crypto.randomUUID()}`,
|
||||
});
|
||||
runtime.lastStopReason = "requires_approval";
|
||||
return;
|
||||
}
|
||||
|
||||
type Decision =
|
||||
| {
|
||||
type: "approve";
|
||||
@@ -1154,6 +1268,10 @@ async function recoverPendingApprovals(
|
||||
}
|
||||
|
||||
const executionResults = await executeApprovalBatch(decisions);
|
||||
clearPendingApprovalBatchIds(
|
||||
runtime,
|
||||
decisions.map((decision) => decision.approval),
|
||||
);
|
||||
|
||||
await handleIncomingMessage(
|
||||
{
|
||||
@@ -1170,6 +1288,9 @@ async function recoverPendingApprovals(
|
||||
},
|
||||
socket,
|
||||
runtime,
|
||||
undefined,
|
||||
undefined,
|
||||
recoveryBatchId,
|
||||
);
|
||||
} finally {
|
||||
runtime.isRecoveringApprovals = false;
|
||||
@@ -1474,12 +1595,19 @@ async function connectWithRetry(
|
||||
const firstPayload = parsed.messages.at(0);
|
||||
const isUserMessage =
|
||||
firstPayload !== undefined && "content" in firstPayload;
|
||||
let enqueuedQueueItemId: string | null = null;
|
||||
if (isUserMessage) {
|
||||
runtime.queueRuntime.enqueue({
|
||||
const userPayload = firstPayload as MessageCreate & {
|
||||
client_message_id?: string;
|
||||
};
|
||||
const enqueuedItem = runtime.queueRuntime.enqueue({
|
||||
kind: "message",
|
||||
source: "user",
|
||||
content: (firstPayload as MessageCreate).content,
|
||||
content: userPayload.content,
|
||||
clientMessageId:
|
||||
userPayload.client_message_id ?? `cm-submit-${crypto.randomUUID()}`,
|
||||
} as Parameters<typeof runtime.queueRuntime.enqueue>[0]);
|
||||
enqueuedQueueItemId = enqueuedItem?.id ?? null;
|
||||
// Emit blocked on state transition when turns are already queued.
|
||||
// pendingTurns is incremented synchronously (below) before .then(),
|
||||
// so a second arrival always sees the correct count.
|
||||
@@ -1497,9 +1625,50 @@ async function connectWithRetry(
|
||||
return;
|
||||
}
|
||||
|
||||
// Signal dequeue for exactly this one turn (one message per chain cb)
|
||||
if (isUserMessage) {
|
||||
runtime.queueRuntime.consumeItems(1);
|
||||
let messageForTurn = parsed;
|
||||
let dequeuedBatchId: string | null = null;
|
||||
if (isUserMessage && enqueuedQueueItemId) {
|
||||
if (runtime.coalescedSkipQueueItemIds.has(enqueuedQueueItemId)) {
|
||||
runtime.coalescedSkipQueueItemIds.delete(enqueuedQueueItemId);
|
||||
runtime.pendingTurns--;
|
||||
if (runtime.pendingTurns === 0) {
|
||||
runtime.queueRuntime.resetBlockedState();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const dequeuedBatch = runtime.queueRuntime.tryDequeue(null);
|
||||
if (!dequeuedBatch) {
|
||||
runtime.pendingTurns--;
|
||||
if (runtime.pendingTurns === 0) {
|
||||
runtime.queueRuntime.resetBlockedState();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
dequeuedBatchId = dequeuedBatch.batchId;
|
||||
for (const item of dequeuedBatch.items) {
|
||||
if (item.id !== enqueuedQueueItemId) {
|
||||
runtime.coalescedSkipQueueItemIds.add(item.id);
|
||||
}
|
||||
}
|
||||
|
||||
const mergedContent = mergeDequeuedBatchContent(
|
||||
dequeuedBatch.items,
|
||||
);
|
||||
if (mergedContent !== null) {
|
||||
const firstMessage = parsed.messages.at(0);
|
||||
if (firstMessage && "content" in firstMessage) {
|
||||
const mergedFirstMessage = {
|
||||
...firstMessage,
|
||||
content: mergedContent,
|
||||
};
|
||||
messageForTurn = {
|
||||
...parsed,
|
||||
messages: [mergedFirstMessage, ...parsed.messages.slice(1)],
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onStatusChange("receiving") is inside try so that any throw
|
||||
@@ -1507,11 +1676,12 @@ async function connectWithRetry(
|
||||
try {
|
||||
opts.onStatusChange?.("receiving", opts.connectionId);
|
||||
await handleIncomingMessage(
|
||||
parsed,
|
||||
messageForTurn,
|
||||
socket,
|
||||
runtime,
|
||||
opts.onStatusChange,
|
||||
opts.connectionId,
|
||||
dequeuedBatchId ?? `batch-direct-${crypto.randomUUID()}`,
|
||||
);
|
||||
opts.onStatusChange?.("idle", opts.connectionId);
|
||||
} finally {
|
||||
@@ -1544,6 +1714,7 @@ async function connectWithRetry(
|
||||
|
||||
// Single authoritative queue_cleared emission for all close paths
|
||||
// (intentional and unintentional). Must fire before early returns.
|
||||
runtime.coalescedSkipQueueItemIds.clear();
|
||||
runtime.queueRuntime.clear("shutdown");
|
||||
|
||||
if (process.env.DEBUG) {
|
||||
@@ -1613,6 +1784,7 @@ async function handleIncomingMessage(
|
||||
connectionId: string,
|
||||
) => void,
|
||||
connectionId?: string,
|
||||
dequeuedBatchId: string = `batch-direct-${crypto.randomUUID()}`,
|
||||
): Promise<void> {
|
||||
// Hoist identifiers and tracking state so they're available in catch for error-result
|
||||
const agentId = msg.agentId;
|
||||
@@ -1761,6 +1933,7 @@ async function handleIncomingMessage(
|
||||
sendClientMessage(socket, {
|
||||
type: "run_started",
|
||||
runId: maybeRunId,
|
||||
batch_id: dequeuedBatchId,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1916,6 +2089,10 @@ async function handleIncomingMessage(
|
||||
break;
|
||||
}
|
||||
|
||||
// Persist origin correlation for this approval wait so a later recovery
|
||||
// can continue the same dequeued-turn run block.
|
||||
rememberPendingApprovalBatchIds(runtime, approvals, dequeuedBatchId);
|
||||
|
||||
// Classify approvals (auto-allow, auto-deny, needs user input)
|
||||
// Don't treat "ask" as deny - cloud UI can handle approvals
|
||||
// Interactive tools (AskUserQuestion, EnterPlanMode, ExitPlanMode) always need user input
|
||||
@@ -2078,6 +2255,10 @@ async function handleIncomingMessage(
|
||||
abortSignal: runtime.activeAbortController.signal,
|
||||
},
|
||||
);
|
||||
clearPendingApprovalBatchIds(
|
||||
runtime,
|
||||
decisions.map((decision) => decision.approval),
|
||||
);
|
||||
|
||||
// Create fresh approval stream for next iteration
|
||||
stream = await sendMessageStreamWithRetry(
|
||||
@@ -2198,4 +2379,7 @@ export const __listenClientTestUtils = {
|
||||
createRuntime,
|
||||
stopRuntime,
|
||||
emitToWS,
|
||||
rememberPendingApprovalBatchIds,
|
||||
resolvePendingApprovalBatchId,
|
||||
clearPendingApprovalBatchIds,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user