From 6cbba40ff975b04ba4aba4760854db2239611092 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 19 Mar 2026 10:11:37 -0700 Subject: [PATCH] fix(listener): drain queued turns during approval reentry (#1448) Co-authored-by: Letta Code --- .../listen-client-concurrency.test.ts | 325 +++++++++++++++++- .../websocket/listen-client-protocol.test.ts | 39 +++ .../websocket/listenerQueueAdapter.test.ts | 55 ++- src/types/protocol.ts | 2 +- src/websocket/helpers/listenerQueueAdapter.ts | 15 +- src/websocket/listener/approval.ts | 10 +- src/websocket/listener/client.ts | 174 +++++++--- src/websocket/listener/queue.ts | 68 +++- src/websocket/listener/recovery.ts | 32 +- src/websocket/listener/send.ts | 37 +- src/websocket/listener/turn-approval.ts | 24 +- src/websocket/listener/turn.ts | 4 +- 12 files changed, 687 insertions(+), 98 deletions(-) diff --git a/src/tests/websocket/listen-client-concurrency.test.ts b/src/tests/websocket/listen-client-concurrency.test.ts index 71f285d..a323419 100644 --- a/src/tests/websocket/listen-client-concurrency.test.ts +++ b/src/tests/websocket/listen-client-concurrency.test.ts @@ -1,7 +1,12 @@ import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; import WebSocket from "ws"; +import type { ResumeData } from "../../agent/check-approval"; import { permissionMode } from "../../permissions/mode"; -import type { MessageQueueItem } from "../../queue/queueRuntime"; +import type { + MessageQueueItem, + TaskNotificationQueueItem, +} from "../../queue/queueRuntime"; +import type { IncomingMessage } from "../../websocket/listener/types"; type MockStream = { conversationId: string; @@ -53,12 +58,29 @@ const drainStreamWithResumeMock = mock( return defaultDrainResult; }, ); +const retrieveAgentMock = mock(async (agentId: string) => ({ id: agentId })); const cancelConversationMock = mock(async (_conversationId: string) => {}); const getClientMock = mock(async () => ({ + agents: { + retrieve: retrieveAgentMock, + }, conversations: { cancel: cancelConversationMock, }, })); +const getResumeDataMock = mock( + async (): Promise => ({ + pendingApproval: null, + pendingApprovals: [], + messageHistory: [], + }), +); +const classifyApprovalsMock = mock(async () => ({ + autoAllowed: [], + autoDenied: [], + needsUserInput: [], +})); +const executeApprovalBatchMock = mock(async () => []); const fetchRunErrorDetailMock = mock(async () => null); const realStreamModule = await import("../../cli/helpers/stream"); @@ -100,6 +122,14 @@ mock.module("../../agent/client", () => ({ consumeLastSDKDiagnostic: () => null, })); +mock.module("../../cli/helpers/approvalClassification", () => ({ + classifyApprovals: classifyApprovalsMock, +})); + +mock.module("../../agent/approval-execution", () => ({ + executeApprovalBatch: executeApprovalBatchMock, +})); + mock.module("../../agent/approval-recovery", () => ({ fetchRunErrorDetail: fetchRunErrorDetailMock, })); @@ -172,6 +202,10 @@ describe("listen-client multi-worker concurrency", () => { getStreamToolContextIdMock.mockClear(); drainStreamWithResumeMock.mockClear(); getClientMock.mockClear(); + retrieveAgentMock.mockClear(); + getResumeDataMock.mockClear(); + classifyApprovalsMock.mockClear(); + executeApprovalBatchMock.mockClear(); cancelConversationMock.mockClear(); fetchRunErrorDetailMock.mockClear(); drainHandlers.clear(); @@ -573,6 +607,205 @@ describe("listen-client multi-worker concurrency", () => { expect(runtimeB.queuedMessagesByItemId.size).toBe(0); }); + test("consumeQueuedTurn only drains the next same-scope queued turn batch", () => { + const runtime = __listenClientTestUtils.createRuntime(); + const messageInput = { + kind: "message", + source: "user", + content: "queued user", + clientMessageId: "cm-user", + agentId: "agent-1", + conversationId: "conv-1", + } satisfies Omit; + const messageItem = runtime.queueRuntime.enqueue(messageInput); + + if (!messageItem) { + throw new Error("Expected queued message item"); + } + + runtime.queuedMessagesByItemId.set( + messageItem.id, + makeIncomingMessage("agent-1", "conv-1", "queued user"), + ); + + const taskInput = { + kind: "task_notification", + source: "system", + text: "done", + clientMessageId: "cm-task", + agentId: "agent-1", + conversationId: "conv-1", + } satisfies Omit; + const taskItem = runtime.queueRuntime.enqueue(taskInput); + + if (!taskItem) { + throw new Error("Expected queued task notification item"); + } + + const otherMessageInput = { + kind: "message", + source: "user", + content: "queued other", + clientMessageId: "cm-other", + agentId: "agent-1", + conversationId: "conv-2", + } satisfies Omit; + const otherMessageItem = runtime.queueRuntime.enqueue(otherMessageInput); + + if (!otherMessageItem) { + throw new Error("Expected second queued message item"); + } + + runtime.queuedMessagesByItemId.set( + otherMessageItem.id, + makeIncomingMessage("agent-1", "conv-2", "queued other"), + ); + + const consumed = __listenClientTestUtils.consumeQueuedTurn(runtime); + + expect(consumed).not.toBeNull(); + expect( + consumed?.dequeuedBatch.items.map((item: { id: string }) => item.id), + ).toEqual([messageItem.id, taskItem.id]); + expect(consumed?.queuedTurn.messages).toEqual([ + { + role: "user", + content: [ + { type: "text", text: "queued user" }, + { type: "text", text: "\n" }, + { + type: "text", + text: "done", + }, + ], + }, + ]); + expect(runtime.queueRuntime.length).toBe(1); + expect(runtime.queuedMessagesByItemId.has(otherMessageItem.id)).toBe(true); + }); + + test("resolveStaleApprovals injects queued turns and marks recovery drain as processing", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + runtime.agentId = "agent-1"; + runtime.conversationId = "conv-1"; + runtime.activeWorkingDirectory = "/tmp/project"; + runtime.loopStatus = "WAITING_FOR_API_RESPONSE"; + const socket = new MockSocket(); + const drain = createDeferredDrain(); + drainHandlers.set("conv-1", () => drain.promise); + + const approval = { + toolCallId: "tool-call-1", + toolName: "Write", + toolArgs: '{"file_path":"foo.ts"}', + }; + const approvalResult = { + type: "tool", + tool_call_id: "tool-call-1", + tool_return: "ok", + status: "success", + }; + + getResumeDataMock.mockResolvedValueOnce({ + pendingApproval: approval, + pendingApprovals: [approval], + messageHistory: [], + }); + classifyApprovalsMock.mockResolvedValueOnce({ + autoAllowed: [ + { + approval, + parsedArgs: { file_path: "foo.ts" }, + }, + ], + autoDenied: [], + needsUserInput: [], + } as never); + executeApprovalBatchMock.mockResolvedValueOnce([approvalResult] as never); + + const queuedMessageInput = { + kind: "message", + source: "user", + content: "queued user", + clientMessageId: "cm-stale-user", + agentId: "agent-1", + conversationId: "conv-1", + } satisfies Omit; + const queuedMessageItem = runtime.queueRuntime.enqueue(queuedMessageInput); + if (!queuedMessageItem) { + throw new Error("Expected stale recovery queued message item"); + } + runtime.queuedMessagesByItemId.set( + queuedMessageItem.id, + makeIncomingMessage("agent-1", "conv-1", "queued user"), + ); + + const queuedTaskInput = { + kind: "task_notification", + source: "system", + text: "done", + clientMessageId: "cm-stale-task", + agentId: "agent-1", + conversationId: "conv-1", + } satisfies Omit; + const queuedTaskItem = runtime.queueRuntime.enqueue(queuedTaskInput); + if (!queuedTaskItem) { + throw new Error("Expected stale recovery queued task item"); + } + + const recoveryPromise = __listenClientTestUtils.resolveStaleApprovals( + runtime, + socket as unknown as WebSocket, + new AbortController().signal, + { getResumeData: getResumeDataMock }, + ); + + await waitFor(() => sendMessageStreamMock.mock.calls.length === 1); + await waitFor(() => drainStreamWithResumeMock.mock.calls.length === 1); + + const continuationMessages = sendMessageStreamMock.mock.calls[0]?.[1] as + | Array> + | undefined; + expect(continuationMessages).toHaveLength(2); + expect(continuationMessages?.[0]).toEqual({ + type: "approval", + approvals: [approvalResult], + }); + expect(continuationMessages?.[1]).toEqual({ + role: "user", + content: [ + { type: "text", text: "queued user" }, + { type: "text", text: "\n" }, + { + type: "text", + text: "done", + }, + ], + }); + expect(runtime.loopStatus as string).toBe("PROCESSING_API_RESPONSE"); + expect(runtime.queueRuntime.length).toBe(0); + expect(runtime.queuedMessagesByItemId.size).toBe(0); + expect( + socket.sentPayloads.some( + (payload) => + payload.includes("queued user") && + payload.includes("done"), + ), + ).toBe(true); + + drain.resolve({ + stopReason: "end_turn", + approvals: [], + apiDurationMs: 0, + }); + + await expect(recoveryPromise).resolves.toEqual({ + stopReason: "end_turn", + approvals: [], + apiDurationMs: 0, + }); + }); + test("queue pump status callbacks stay aggregate when another conversation is busy", async () => { const listener = __listenClientTestUtils.createListenerRuntime(); __listenClientTestUtils.setActiveRuntime(listener); @@ -628,4 +861,94 @@ describe("listen-client multi-worker concurrency", () => { expect(listener.conversationRuntimes.has(runtimeB.key)).toBe(false); expect(listener.conversationRuntimes.has(runtimeA.key)).toBe(true); }); + + test("change_device_state command holds queued input until the tracked command completes", async () => { + const listener = __listenClientTestUtils.createListenerRuntime(); + __listenClientTestUtils.setActiveRuntime(listener); + const runtime = __listenClientTestUtils.getOrCreateScopedRuntime( + listener, + "agent-1", + "conv-a", + ); + const socket = new MockSocket(); + const processedTurns: string[] = []; + + const queueInput = { + kind: "message", + source: "user", + content: "queued during command", + clientMessageId: "cm-command", + agentId: "agent-1", + conversationId: "conv-a", + } satisfies Omit; + const item = runtime.queueRuntime.enqueue(queueInput); + if (!item) { + throw new Error("Expected queued item to be created"); + } + runtime.queuedMessagesByItemId.set( + item.id, + makeIncomingMessage("agent-1", "conv-a", "queued during command"), + ); + + let releaseCommand!: () => void; + const commandHold = new Promise((resolve) => { + releaseCommand = resolve; + }); + const processQueuedTurn = async ( + queuedTurn: IncomingMessage, + _dequeuedBatch: unknown, + ) => { + processedTurns.push(queuedTurn.conversationId ?? "default"); + }; + + const commandPromise = __listenClientTestUtils.handleChangeDeviceStateInput( + listener, + { + command: { + type: "change_device_state", + runtime: { agent_id: "agent-1", conversation_id: "conv-a" }, + payload: { cwd: "/tmp/next" }, + }, + socket: socket as unknown as WebSocket, + opts: {}, + processQueuedTurn, + }, + { + handleCwdChange: async () => { + await commandHold; + }, + }, + ); + + await waitFor(() => runtime.loopStatus === "EXECUTING_COMMAND"); + + __listenClientTestUtils.scheduleQueuePump( + runtime, + socket as unknown as WebSocket, + {} as never, + processQueuedTurn, + ); + + await waitFor( + () => + runtime.queueRuntime.length === 1 && + !runtime.queuePumpScheduled && + !runtime.queuePumpActive, + ); + + expect(processedTurns).toEqual([]); + expect(runtime.queueRuntime.length).toBe(1); + expect(runtime.loopStatus).toBe("EXECUTING_COMMAND"); + + releaseCommand(); + await commandPromise; + + await waitFor( + () => processedTurns.length === 1 && runtime.queueRuntime.length === 0, + ); + + expect(processedTurns).toEqual(["conv-a"]); + expect(runtime.loopStatus).toBe("WAITING_ON_INPUT"); + expect(runtime.queuedMessagesByItemId.size).toBe(0); + }); }); diff --git a/src/tests/websocket/listen-client-protocol.test.ts b/src/tests/websocket/listen-client-protocol.test.ts index 561494e..d81d3df 100644 --- a/src/tests/websocket/listen-client-protocol.test.ts +++ b/src/tests/websocket/listen-client-protocol.test.ts @@ -228,6 +228,30 @@ describe("listen-client permission mode scope keys", () => { }); describe("listen-client approval resolver wiring", () => { + test("resolved approvals restore WAITING_ON_INPUT instead of faking processing", () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + runtime.isProcessing = true; + runtime.loopStatus = "WAITING_ON_APPROVAL"; + + void requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + "perm-status", + makeControlRequest("perm-status"), + ).catch(() => {}); + + expect(runtime.loopStatus).toBe("WAITING_ON_APPROVAL"); + + const resolved = resolvePendingApprovalResolver(runtime, { + request_id: "perm-status", + decision: { behavior: "allow" }, + }); + + expect(resolved).toBe(true); + expect(runtime.loopStatus as string).toBe("WAITING_ON_INPUT"); + }); + test("resolves matching pending resolver", async () => { const runtime = __listenClientTestUtils.createRuntime(); const socket = new MockSocket(WebSocket.OPEN); @@ -305,6 +329,21 @@ describe("listen-client approval resolver wiring", () => { await expect(second).rejects.toThrow("socket closed"); }); + test("cleanup resets WAITING_ON_INPUT instead of restoring fake processing", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + runtime.isProcessing = true; + runtime.loopStatus = "WAITING_ON_APPROVAL"; + + const pending = new Promise((resolve, reject) => { + runtime.pendingApprovalResolvers.set("perm-cleanup", { resolve, reject }); + }); + + rejectPendingApprovalResolvers(runtime, "socket closed"); + + expect(runtime.loopStatus as string).toBe("WAITING_ON_INPUT"); + await expect(pending).rejects.toThrow("socket closed"); + }); + test("stopRuntime rejects pending resolvers even when callbacks are suppressed", async () => { const runtime = __listenClientTestUtils.createRuntime(); const pending = new Promise((resolve, reject) => { diff --git a/src/tests/websocket/listenerQueueAdapter.test.ts b/src/tests/websocket/listenerQueueAdapter.test.ts index 8fc2e2c..274447c 100644 --- a/src/tests/websocket/listenerQueueAdapter.test.ts +++ b/src/tests/websocket/listenerQueueAdapter.test.ts @@ -2,6 +2,7 @@ import { describe, expect, test } from "bun:test"; import { getListenerBlockedReason } from "../../websocket/helpers/listenerQueueAdapter"; const allClear = { + loopStatus: "WAITING_ON_INPUT", isProcessing: false, pendingApprovalsLen: 0, cancelRequested: false, @@ -19,11 +20,13 @@ describe("getListenerBlockedReason", () => { ).toBe("pending_approvals"); }); - test("prioritizes interrupt over runtime busy", () => { + test("prioritizes interrupt over approval and streaming phases", () => { expect( getListenerBlockedReason({ ...allClear, cancelRequested: true, + pendingApprovalsLen: 2, + loopStatus: "PROCESSING_API_RESPONSE", isProcessing: true, }), ).toBe("interrupt_in_progress"); @@ -31,13 +34,53 @@ describe("getListenerBlockedReason", () => { test("maps recoveries to runtime busy", () => { expect( - getListenerBlockedReason({ ...allClear, isRecoveringApprovals: true }), + getListenerBlockedReason({ + ...allClear, + isRecoveringApprovals: true, + loopStatus: "EXECUTING_COMMAND", + }), ).toBe("runtime_busy"); }); - test("maps active processing to runtime busy", () => { - expect(getListenerBlockedReason({ ...allClear, isProcessing: true })).toBe( - "runtime_busy", - ); + test("maps waiting-on-approval phase to pending approvals", () => { + expect( + getListenerBlockedReason({ + ...allClear, + loopStatus: "WAITING_ON_APPROVAL", + }), + ).toBe("pending_approvals"); + }); + + test("maps command execution to command_running", () => { + expect( + getListenerBlockedReason({ + ...allClear, + loopStatus: "EXECUTING_COMMAND", + }), + ).toBe("command_running"); + }); + + test.each([ + "SENDING_API_REQUEST", + "RETRYING_API_REQUEST", + "WAITING_FOR_API_RESPONSE", + "PROCESSING_API_RESPONSE", + "EXECUTING_CLIENT_SIDE_TOOL", + ] as const)("maps %s to streaming", (loopStatus) => { + expect( + getListenerBlockedReason({ + ...allClear, + loopStatus, + }), + ).toBe("streaming"); + }); + + test("falls back to runtime busy when processing without a specific phase", () => { + expect( + getListenerBlockedReason({ + ...allClear, + isProcessing: true, + }), + ).toBe("runtime_busy"); }); }); diff --git a/src/types/protocol.ts b/src/types/protocol.ts index e22af37..1a34c94 100644 --- a/src/types/protocol.ts +++ b/src/types/protocol.ts @@ -430,7 +430,7 @@ export interface QueueBatchDequeuedEvent extends MessageEnvelope { /** * Why the queue cannot dequeue right now. - * - streaming: Agent turn is actively streaming + * - streaming: Agent turn is actively running/streaming (request, response, or local tool execution) * - pending_approvals: Waiting for HITL approval decisions * - overlay_open: Plan mode, AskUserQuestion, or other overlay is active * - command_running: Slash command is executing diff --git a/src/websocket/helpers/listenerQueueAdapter.ts b/src/websocket/helpers/listenerQueueAdapter.ts index 8f16bc5..6b382e8 100644 --- a/src/websocket/helpers/listenerQueueAdapter.ts +++ b/src/websocket/helpers/listenerQueueAdapter.ts @@ -1,6 +1,8 @@ import type { QueueBlockedReason } from "../../queue/queueRuntime"; +import type { LoopStatus } from "../../types/protocol_v2"; export type ListenerQueueGatingConditions = { + loopStatus: LoopStatus; isProcessing: boolean; pendingApprovalsLen: number; cancelRequested: boolean; @@ -10,9 +12,20 @@ export type ListenerQueueGatingConditions = { export function getListenerBlockedReason( c: ListenerQueueGatingConditions, ): QueueBlockedReason | null { - if (c.pendingApprovalsLen > 0) return "pending_approvals"; if (c.cancelRequested) return "interrupt_in_progress"; + if (c.pendingApprovalsLen > 0) return "pending_approvals"; if (c.isRecoveringApprovals) return "runtime_busy"; + if (c.loopStatus === "WAITING_ON_APPROVAL") return "pending_approvals"; + if (c.loopStatus === "EXECUTING_COMMAND") return "command_running"; + if ( + c.loopStatus === "SENDING_API_REQUEST" || + c.loopStatus === "RETRYING_API_REQUEST" || + c.loopStatus === "WAITING_FOR_API_RESPONSE" || + c.loopStatus === "PROCESSING_API_RESPONSE" || + c.loopStatus === "EXECUTING_CLIENT_SIDE_TOOL" + ) { + return "streaming"; + } if (c.isProcessing) return "runtime_busy"; return null; } diff --git a/src/websocket/listener/approval.ts b/src/websocket/listener/approval.ts index 5894389..7021b6c 100644 --- a/src/websocket/listener/approval.ts +++ b/src/websocket/listener/approval.ts @@ -197,10 +197,7 @@ export function resolvePendingApprovalResolver( runtime.pendingApprovalResolvers.delete(requestId); runtime.listener.approvalRuntimeKeyByRequestId.delete(requestId); if (runtime.pendingApprovalResolvers.size === 0) { - setLoopStatus( - runtime, - runtime.isProcessing ? "PROCESSING_API_RESPONSE" : "WAITING_ON_INPUT", - ); + setLoopStatus(runtime, "WAITING_ON_INPUT"); } pending.resolve(response); emitLoopStatusIfOpen(runtime.listener, { @@ -229,10 +226,7 @@ export function rejectPendingApprovalResolvers( runtime.listener.approvalRuntimeKeyByRequestId.delete(requestId); } } - setLoopStatus( - runtime, - runtime.isProcessing ? "PROCESSING_API_RESPONSE" : "WAITING_ON_INPUT", - ); + setLoopStatus(runtime, "WAITING_ON_INPUT"); emitLoopStatusIfOpen(runtime.listener, { agent_id: runtime.agentId, conversation_id: runtime.conversationId, diff --git a/src/websocket/listener/client.ts b/src/websocket/listener/client.ts index 0906873..742105a 100644 --- a/src/websocket/listener/client.ts +++ b/src/websocket/listener/client.ts @@ -15,7 +15,10 @@ import { type DequeuedBatch, QueueRuntime } from "../../queue/queueRuntime"; import { createSharedReminderState } from "../../reminders/state"; import { settingsManager } from "../../settings-manager"; import { loadTools } from "../../tools/manager"; -import type { ApprovalResponseBody } from "../../types/protocol_v2"; +import type { + ApprovalResponseBody, + ChangeDeviceStateCommand, +} from "../../types/protocol_v2"; import { isDebugEnabled } from "../../utils/debug"; import { handleTerminalInput, @@ -72,6 +75,7 @@ import { setLoopStatus, } from "./protocol-outbound"; import { + consumeQueuedTurn, getQueueItemScope, getQueueItemsScope, normalizeInboundMessages, @@ -104,7 +108,10 @@ import { normalizeCwdAgentId, resolveRuntimeScope, } from "./scope"; -import { markAwaitingAcceptedApprovalContinuationRunId } from "./send"; +import { + markAwaitingAcceptedApprovalContinuationRunId, + resolveStaleApprovals, +} from "./send"; import { handleIncomingMessage } from "./turn"; import type { ChangeCwdMessage, @@ -350,6 +357,108 @@ async function handleApprovalResponseInput( return false; } +async function handleChangeDeviceStateInput( + listener: ListenerRuntime, + params: { + command: ChangeDeviceStateCommand; + socket: WebSocket; + opts: { + onStatusChange?: StartListenerOptions["onStatusChange"]; + connectionId?: string; + }; + processQueuedTurn: ProcessQueuedTurn; + }, + deps: Partial<{ + getActiveRuntime: typeof getActiveRuntime; + getOrCreateScopedRuntime: typeof getOrCreateScopedRuntime; + getPendingControlRequestCount: typeof getPendingControlRequestCount; + setLoopStatus: typeof setLoopStatus; + handleModeChange: typeof handleModeChange; + handleCwdChange: typeof handleCwdChange; + emitDeviceStatusUpdate: typeof emitDeviceStatusUpdate; + scheduleQueuePump: typeof scheduleQueuePump; + }> = {}, +): Promise { + const resolvedDeps = { + getActiveRuntime, + getOrCreateScopedRuntime, + getPendingControlRequestCount, + setLoopStatus, + handleModeChange, + handleCwdChange, + emitDeviceStatusUpdate, + scheduleQueuePump, + ...deps, + }; + + if ( + listener !== resolvedDeps.getActiveRuntime() || + listener.intentionallyClosed + ) { + return false; + } + + const scope = { + agent_id: + params.command.payload.agent_id ?? + params.command.runtime.agent_id ?? + undefined, + conversation_id: + params.command.payload.conversation_id ?? + params.command.runtime.conversation_id ?? + undefined, + }; + const scopedRuntime = resolvedDeps.getOrCreateScopedRuntime( + listener, + scope.agent_id, + scope.conversation_id, + ); + const shouldTrackCommand = + !scopedRuntime.isProcessing && + resolvedDeps.getPendingControlRequestCount(listener, scope) === 0; + + if (shouldTrackCommand) { + resolvedDeps.setLoopStatus(scopedRuntime, "EXECUTING_COMMAND", scope); + } + + try { + if (params.command.payload.mode) { + resolvedDeps.handleModeChange( + { mode: params.command.payload.mode }, + params.socket, + listener, + scope, + ); + } + + if (params.command.payload.cwd) { + await resolvedDeps.handleCwdChange( + { + agentId: scope.agent_id ?? null, + conversationId: scope.conversation_id ?? null, + cwd: params.command.payload.cwd, + }, + params.socket, + scopedRuntime, + ); + } else if (!params.command.payload.mode) { + resolvedDeps.emitDeviceStatusUpdate(params.socket, listener, scope); + } + } finally { + if (shouldTrackCommand) { + resolvedDeps.setLoopStatus(scopedRuntime, "WAITING_ON_INPUT", scope); + resolvedDeps.scheduleQueuePump( + scopedRuntime, + params.socket, + params.opts as StartListenerOptions, + params.processQueuedTurn, + ); + } + } + + return true; +} + async function handleCwdChange( msg: ChangeCwdMessage, socket: WebSocket, @@ -776,55 +885,15 @@ async function connectWithRetry( } if (parsed.type === "change_device_state") { - if (runtime !== getActiveRuntime() || runtime.intentionallyClosed) { - return; - } - const scope = { - agent_id: - parsed.payload.agent_id ?? parsed.runtime.agent_id ?? undefined, - conversation_id: - parsed.payload.conversation_id ?? - parsed.runtime.conversation_id ?? - undefined, - }; - const scopedRuntime = getOrCreateScopedRuntime( - runtime, - scope.agent_id, - scope.conversation_id, - ); - const shouldTrackCommand = - !scopedRuntime.isProcessing && - getPendingControlRequestCount(runtime, scope) === 0; - if (shouldTrackCommand) { - setLoopStatus(scopedRuntime, "EXECUTING_COMMAND", scope); - } - try { - if (parsed.payload.mode) { - handleModeChange( - { mode: parsed.payload.mode }, - socket, - runtime, - scope, - ); - } - if (parsed.payload.cwd) { - await handleCwdChange( - { - agentId: scope.agent_id ?? null, - conversationId: scope.conversation_id ?? null, - cwd: parsed.payload.cwd, - }, - socket, - scopedRuntime, - ); - } else if (!parsed.payload.mode) { - emitDeviceStatusUpdate(socket, runtime, scope); - } - } finally { - if (shouldTrackCommand) { - setLoopStatus(scopedRuntime, "WAITING_ON_INPUT", scope); - } - } + await handleChangeDeviceStateInput(runtime, { + command: parsed, + socket, + opts: { + onStatusChange: opts.onStatusChange, + connectionId: opts.connectionId, + }, + processQueuedTurn, + }); return; } @@ -1299,10 +1368,13 @@ export const __listenClientTestUtils = { shouldAttemptPostStopApprovalRecovery, getApprovalContinuationRecoveryDisposition, markAwaitingAcceptedApprovalContinuationRunId, + resolveStaleApprovals, normalizeMessageContentImages, normalizeInboundMessages, + consumeQueuedTurn, handleIncomingMessage, handleApprovalResponseInput, + handleChangeDeviceStateInput, scheduleQueuePump, recoverApprovalStateForSync, clearRecoveredApprovalStateForScope: ( diff --git a/src/websocket/listener/queue.ts b/src/websocket/listener/queue.ts index ba1fcaa..cbcc50b 100644 --- a/src/websocket/listener/queue.ts +++ b/src/websocket/listener/queue.ts @@ -6,6 +6,7 @@ import type { QueueBlockedReason, QueueItem, } from "../../queue/queueRuntime"; +import { isCoalescable } from "../../queue/queueRuntime"; import { mergeQueuedTurnInput } from "../../queue/turnQueueRuntime"; import { getListenerBlockedReason } from "../helpers/listenerQueueAdapter"; import { emitDequeuedUserMessage } from "./protocol-outbound"; @@ -53,6 +54,13 @@ export function getQueueItemsScope(items: QueueItem[]): { return sameScope ? getQueueItemScope(first) : {}; } +function hasSameQueueScope(a: QueueItem, b: QueueItem): boolean { + return ( + (a.agentId ?? null) === (b.agentId ?? null) && + (a.conversationId ?? null) === (b.conversationId ?? null) + ); +} + function mergeDequeuedBatchContent( items: QueueItem[], ): MessageCreate["content"] | null { @@ -246,6 +254,51 @@ export function shouldQueueInboundMessage(parsed: IncomingMessage): boolean { return parsed.messages.some((payload) => "content" in payload); } +export function consumeQueuedTurn(runtime: ConversationRuntime): { + dequeuedBatch: DequeuedBatch; + queuedTurn: IncomingMessage; +} | null { + const queuedItems = runtime.queueRuntime.peek(); + const firstQueuedItem = queuedItems[0]; + if (!firstQueuedItem || !isCoalescable(firstQueuedItem.kind)) { + return null; + } + + let queueLen = 0; + let hasMessage = false; + for (const item of queuedItems) { + if ( + !isCoalescable(item.kind) || + !hasSameQueueScope(firstQueuedItem, item) + ) { + break; + } + queueLen += 1; + if (item.kind === "message") { + hasMessage = true; + } + } + + if (!hasMessage || queueLen === 0) { + return null; + } + + const dequeuedBatch = runtime.queueRuntime.consumeItems(queueLen); + if (!dequeuedBatch) { + return null; + } + + const queuedTurn = buildQueuedTurnMessage(runtime, dequeuedBatch); + if (!queuedTurn) { + return null; + } + + return { + dequeuedBatch, + queuedTurn, + }; +} + function computeListenerQueueBlockedReason( runtime: ConversationRuntime, ): QueueBlockedReason | null { @@ -254,6 +307,7 @@ function computeListenerQueueBlockedReason( conversation_id: runtime.conversationId, }); return getListenerBlockedReason({ + loopStatus: runtime.loopStatus, isProcessing: runtime.isProcessing, pendingApprovalsLen: activeScope ? getPendingControlRequestCount(runtime.listener, activeScope) @@ -292,20 +346,12 @@ async function drainQueuedMessages( return; } - const queueLen = runtime.queueRuntime.length; - if (queueLen === 0) { + const consumedQueuedTurn = consumeQueuedTurn(runtime); + if (!consumedQueuedTurn) { return; } - const dequeuedBatch = runtime.queueRuntime.consumeItems(queueLen); - if (!dequeuedBatch) { - return; - } - - const queuedTurn = buildQueuedTurnMessage(runtime, dequeuedBatch); - if (!queuedTurn) { - continue; - } + const { dequeuedBatch, queuedTurn } = consumedQueuedTurn; emitDequeuedUserMessage(socket, runtime, queuedTurn, dequeuedBatch); diff --git a/src/websocket/listener/recovery.ts b/src/websocket/listener/recovery.ts index 9e37caa..6d91258 100644 --- a/src/websocket/listener/recovery.ts +++ b/src/websocket/listener/recovery.ts @@ -1,6 +1,10 @@ import { APIError } from "@letta-ai/letta-client/core/error"; import type { Stream } from "@letta-ai/letta-client/core/streaming"; -import type { LettaStreamingResponse } from "@letta-ai/letta-client/resources/agents/messages"; +import type { MessageCreate } from "@letta-ai/letta-client/resources/agents/agents"; +import type { + ApprovalCreate, + LettaStreamingResponse, +} from "@letta-ai/letta-client/resources/agents/messages"; import type WebSocket from "ws"; import { type ApprovalDecision, @@ -36,12 +40,14 @@ import { } from "./interrupts"; import { emitCanonicalMessageDelta, + emitDequeuedUserMessage, emitInterruptedStatusDelta, emitLoopErrorDelta, emitLoopStatusUpdate, emitRuntimeStateUpdates, setLoopStatus, } from "./protocol-outbound"; +import { consumeQueuedTurn } from "./queue"; import { clearActiveRunState, clearRecoveredApprovalState } from "./runtime"; import type { ConversationRuntime, @@ -560,23 +566,33 @@ export async function resolveRecoveredApprovalResponse( setLoopStatus(runtime, "SENDING_API_REQUEST", scope); emitRuntimeStateUpdates(runtime, scope); + const continuationMessages: Array = [ + { + type: "approval", + approvals: approvalResults, + }, + ]; + let continuationBatchId = `batch-recovered-${crypto.randomUUID()}`; + const consumedQueuedTurn = consumeQueuedTurn(runtime); + if (consumedQueuedTurn) { + const { dequeuedBatch, queuedTurn } = consumedQueuedTurn; + continuationBatchId = dequeuedBatch.batchId; + continuationMessages.push(...queuedTurn.messages); + emitDequeuedUserMessage(socket, runtime, queuedTurn, dequeuedBatch); + } + await processTurn( { type: "message", agentId: recovered.agentId, conversationId: recovered.conversationId, - messages: [ - { - type: "approval", - approvals: approvalResults, - }, - ], + messages: continuationMessages, }, socket, runtime, opts?.onStatusChange, opts?.connectionId, - `batch-recovered-${crypto.randomUUID()}`, + continuationBatchId, ); clearRecoveredApprovalState(runtime); diff --git a/src/websocket/listener/send.ts b/src/websocket/listener/send.ts index 2abe372..f103996 100644 --- a/src/websocket/listener/send.ts +++ b/src/websocket/listener/send.ts @@ -43,10 +43,12 @@ import { } from "./interrupts"; import { getConversationPermissionModeState } from "./permissionMode"; import { + emitDequeuedUserMessage, emitRetryDelta, emitRuntimeStateUpdates, setLoopStatus, } from "./protocol-outbound"; +import { consumeQueuedTurn } from "./queue"; import { drainRecoveryStreamWithEmission, finalizeHandledRecoveryTurn, @@ -80,13 +82,18 @@ export function markAwaitingAcceptedApprovalContinuationRunId( * and auto-denying. This is the Phase 3 bounded recovery mechanism — it does NOT * touch pendingInterruptedResults (that's exclusively owned by handleIncomingMessage). */ -async function resolveStaleApprovals( +export async function resolveStaleApprovals( runtime: ConversationRuntime, socket: WebSocket, abortSignal: AbortSignal, + deps: { + getResumeData?: typeof getResumeData; + } = {}, ): Promise> | null> { if (!runtime.agentId) return null; + const getResumeDataImpl = deps.getResumeData ?? getResumeData; + const client = await getClient(); let agent: Awaited>; try { @@ -102,9 +109,14 @@ async function resolveStaleApprovals( let resumeData: Awaited>; try { - resumeData = await getResumeData(client, agent, requestedConversationId, { - includeMessageHistory: false, - }); + resumeData = await getResumeDataImpl( + client, + agent, + requestedConversationId, + { + includeMessageHistory: false, + }, + ); } catch (err) { if (err instanceof APIError && (err.status === 404 || err.status === 422)) { return null; @@ -274,9 +286,22 @@ async function resolveStaleApprovals( "tool-return", ); + const continuationMessages: Array = [ + { + type: "approval", + approvals: approvalResults, + }, + ]; + const consumedQueuedTurn = consumeQueuedTurn(runtime); + if (consumedQueuedTurn) { + const { dequeuedBatch, queuedTurn } = consumedQueuedTurn; + continuationMessages.push(...queuedTurn.messages); + emitDequeuedUserMessage(socket, runtime, queuedTurn, dequeuedBatch); + } + const recoveryStream = await sendApprovalContinuationWithRetry( recoveryConversationId, - [{ type: "approval", approvals: approvalResults }], + continuationMessages, { agentId: runtime.agentId ?? undefined, streamTokens: true, @@ -294,6 +319,8 @@ async function resolveStaleApprovals( ); } + setLoopStatus(runtime, "PROCESSING_API_RESPONSE", scope); + const drainResult = await drainRecoveryStreamWithEmission( recoveryStream as Stream, socket, diff --git a/src/websocket/listener/turn-approval.ts b/src/websocket/listener/turn-approval.ts index d557c67..586a843 100644 --- a/src/websocket/listener/turn-approval.ts +++ b/src/websocket/listener/turn-approval.ts @@ -31,10 +31,12 @@ import { normalizeExecutionResultsForInterruptParity, } from "./interrupts"; import { + emitDequeuedUserMessage, emitLoopErrorDelta, emitRuntimeStateUpdates, setLoopStatus, } from "./protocol-outbound"; +import { consumeQueuedTurn } from "./queue"; import { debugLogApprovalResumeState } from "./recovery"; import { markAwaitingAcceptedApprovalContinuationRunId, @@ -66,6 +68,7 @@ export type ApprovalBranchResult = { terminated: boolean; stream: Stream | null; currentInput: Array; + dequeuedBatchId: string; pendingNormalizationInterruptedToolCallIds: string[]; turnToolContextId: string | null; lastExecutionResults: ApprovalResult[] | null; @@ -144,6 +147,7 @@ export async function handleApprovalStop(params: { terminated: true, stream: null, currentInput, + dequeuedBatchId, pendingNormalizationInterruptedToolCallIds: [], turnToolContextId, lastExecutionResults: null, @@ -244,11 +248,10 @@ export async function handleApprovalStop(params: { }); } } else { - const denyReason = responseBody.error; decisions.push({ type: "deny", approval: ac.approval, - reason: denyReason, + reason: responseBody.error, }); } } @@ -270,9 +273,7 @@ export async function handleApprovalStop(params: { conversation_id: conversationId, }); const executionRunId = - runId || - runtime.activeRunId || - params.msgRunIds[params.msgRunIds.length - 1]; + runId || runtime.activeRunId || msgRunIds[msgRunIds.length - 1]; emitToolExecutionStartedEvents(socket, runtime, { toolCallIds: lastExecutingToolCallIds, runId: executionRunId, @@ -315,12 +316,22 @@ export async function handleApprovalStop(params: { undefined, "tool-return", ); + const nextInput: Array = [ { type: "approval", approvals: persistedExecutionResults, }, ]; + let continuationBatchId = dequeuedBatchId; + const consumedQueuedTurn = consumeQueuedTurn(runtime); + if (consumedQueuedTurn) { + const { dequeuedBatch, queuedTurn } = consumedQueuedTurn; + continuationBatchId = dequeuedBatch.batchId; + nextInput.push(...queuedTurn.messages); + emitDequeuedUserMessage(socket, runtime, queuedTurn, dequeuedBatch); + } + setLoopStatus(runtime, "SENDING_API_REQUEST", { agent_id: agentId, conversation_id: conversationId, @@ -338,6 +349,7 @@ export async function handleApprovalStop(params: { terminated: true, stream: null, currentInput: nextInput, + dequeuedBatchId: continuationBatchId, pendingNormalizationInterruptedToolCallIds: [], turnToolContextId, lastExecutionResults, @@ -346,6 +358,7 @@ export async function handleApprovalStop(params: { lastApprovalContinuationAccepted: false, }; } + clearPendingApprovalBatchIds( runtime, decisions.map((decision) => decision.approval), @@ -380,6 +393,7 @@ export async function handleApprovalStop(params: { terminated: false, stream, currentInput: nextInput, + dequeuedBatchId: continuationBatchId, pendingNormalizationInterruptedToolCallIds: [], turnToolContextId: null, lastExecutionResults, diff --git a/src/websocket/listener/turn.ts b/src/websocket/listener/turn.ts index c3e7cb3..17612b5 100644 --- a/src/websocket/listener/turn.ts +++ b/src/websocket/listener/turn.ts @@ -113,6 +113,7 @@ export async function handleIncomingMessage( let llmApiErrorRetries = 0; let emptyResponseRetries = 0; let lastApprovalContinuationAccepted = false; + let activeDequeuedBatchId = dequeuedBatchId; let lastExecutionResults: ApprovalResult[] | null = null; let lastExecutingToolCallIds: string[] = []; @@ -660,7 +661,7 @@ export async function handleIncomingMessage( conversationId, turnWorkingDirectory, turnPermissionModeState, - dequeuedBatchId, + dequeuedBatchId: activeDequeuedBatchId, runId, msgRunIds, currentInput, @@ -673,6 +674,7 @@ export async function handleIncomingMessage( } stream = approvalResult.stream; currentInput = approvalResult.currentInput; + activeDequeuedBatchId = approvalResult.dequeuedBatchId; pendingNormalizationInterruptedToolCallIds = approvalResult.pendingNormalizationInterruptedToolCallIds; turnToolContextId = approvalResult.turnToolContextId;