From ce5ba0496f3bb43b9d5a58308b34465a48823cc0 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Wed, 25 Feb 2026 19:05:43 -0800 Subject: [PATCH] feat(listen): add ws control_response resolver plumbing (#1152) --- .../websocket/listen-client-protocol.test.ts | 239 ++++++++++++++++++ src/websocket/listen-client.ts | 122 ++++++++- 2 files changed, 358 insertions(+), 3 deletions(-) create mode 100644 src/tests/websocket/listen-client-protocol.test.ts diff --git a/src/tests/websocket/listen-client-protocol.test.ts b/src/tests/websocket/listen-client-protocol.test.ts new file mode 100644 index 0000000..056e2d5 --- /dev/null +++ b/src/tests/websocket/listen-client-protocol.test.ts @@ -0,0 +1,239 @@ +import { describe, expect, test } from "bun:test"; +import WebSocket from "ws"; +import type { ControlRequest, ControlResponseBody } from "../../types/protocol"; +import { + __listenClientTestUtils, + parseServerMessage, + rejectPendingApprovalResolvers, + requestApprovalOverWS, + resolvePendingApprovalResolver, +} from "../../websocket/listen-client"; + +class MockSocket { + readyState: number; + closeCalls = 0; + removeAllListenersCalls = 0; + sentPayloads: string[] = []; + sendImpl: (data: string) => void = (data) => { + this.sentPayloads.push(data); + }; + + constructor(readyState: number = WebSocket.OPEN) { + this.readyState = readyState; + } + + send(data: string): void { + this.sendImpl(data); + } + + close(): void { + this.closeCalls += 1; + } + + removeAllListeners(): this { + this.removeAllListenersCalls += 1; + return this; + } +} + +function makeControlRequest(requestId: string): ControlRequest { + return { + type: "control_request", + request_id: requestId, + request: { + subtype: "can_use_tool", + tool_name: "Write", + input: {}, + tool_call_id: "call-1", + permission_suggestions: [], + blocked_path: null, + }, + }; +} + +function makeSuccessResponse(requestId: string): ControlResponseBody { + return { + subtype: "success", + request_id: requestId, + response: { behavior: "allow" }, + }; +} + +describe("listen-client parseServerMessage", () => { + test("parses valid control_response with required fields", () => { + const parsed = parseServerMessage( + Buffer.from( + JSON.stringify({ + type: "control_response", + response: { subtype: "success", request_id: "perm-1" }, + }), + ), + ); + expect(parsed).not.toBeNull(); + expect(parsed?.type).toBe("control_response"); + }); + + test("rejects invalid control_response payloads", () => { + const missingResponse = parseServerMessage( + Buffer.from(JSON.stringify({ type: "control_response" })), + ); + expect(missingResponse).toBeNull(); + + const missingRequestId = parseServerMessage( + Buffer.from( + JSON.stringify({ + type: "control_response", + response: { subtype: "success" }, + }), + ), + ); + expect(missingRequestId).toBeNull(); + }); + + test("keeps backward compatibility for message, pong, mode_change", () => { + const msg = parseServerMessage( + Buffer.from(JSON.stringify({ type: "message", messages: [] })), + ); + const pong = parseServerMessage( + Buffer.from(JSON.stringify({ type: "pong" })), + ); + const modeChange = parseServerMessage( + Buffer.from(JSON.stringify({ type: "mode_change", mode: "default" })), + ); + expect(msg?.type).toBe("message"); + expect(pong?.type).toBe("pong"); + expect(modeChange?.type).toBe("mode_change"); + }); +}); + +describe("listen-client approval resolver wiring", () => { + test("resolves matching pending resolver", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + const requestId = "perm-101"; + + const pending = requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ); + expect(runtime.pendingApprovalResolvers.size).toBe(1); + + const resolved = resolvePendingApprovalResolver( + runtime, + makeSuccessResponse(requestId), + ); + expect(resolved).toBe(true); + await expect(pending).resolves.toMatchObject({ + subtype: "success", + request_id: requestId, + }); + expect(runtime.pendingApprovalResolvers.size).toBe(0); + }); + + test("ignores non-matching request_id and keeps pending resolver", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + const requestId = "perm-201"; + + const pending = requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ); + let settled = false; + void pending.then( + () => { + settled = true; + }, + () => { + settled = true; + }, + ); + + const resolved = resolvePendingApprovalResolver( + runtime, + makeSuccessResponse("perm-other"), + ); + expect(resolved).toBe(false); + await Promise.resolve(); + expect(settled).toBe(false); + expect(runtime.pendingApprovalResolvers.size).toBe(1); + + const handledPending = pending.catch((error) => error); + rejectPendingApprovalResolvers(runtime, "cleanup"); + const cleanupError = await handledPending; + expect(cleanupError).toBeInstanceOf(Error); + expect((cleanupError as Error).message).toBe("cleanup"); + }); + + test("cleanup rejects all pending resolvers", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const first = new Promise((resolve, reject) => { + runtime.pendingApprovalResolvers.set("perm-a", { resolve, reject }); + }); + const second = new Promise((resolve, reject) => { + runtime.pendingApprovalResolvers.set("perm-b", { resolve, reject }); + }); + + rejectPendingApprovalResolvers(runtime, "socket closed"); + expect(runtime.pendingApprovalResolvers.size).toBe(0); + await expect(first).rejects.toThrow("socket closed"); + await expect(second).rejects.toThrow("socket closed"); + }); + + test("stopRuntime rejects pending resolvers even when callbacks are suppressed", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const pending = new Promise((resolve, reject) => { + runtime.pendingApprovalResolvers.set("perm-stop", { resolve, reject }); + }); + const socket = new MockSocket(WebSocket.OPEN); + runtime.socket = socket as unknown as WebSocket; + + __listenClientTestUtils.stopRuntime(runtime, true); + + expect(runtime.pendingApprovalResolvers.size).toBe(0); + expect(socket.removeAllListenersCalls).toBe(1); + expect(socket.closeCalls).toBe(1); + await expect(pending).rejects.toThrow("Listener runtime stopped"); + }); +}); + +describe("listen-client requestApprovalOverWS", () => { + test("rejects immediately when socket is not open", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.CLOSED); + const requestId = "perm-closed"; + + await expect( + requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ), + ).rejects.toThrow("WebSocket not open"); + expect(runtime.pendingApprovalResolvers.size).toBe(0); + }); + + test("cleans up resolver when send throws", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + socket.sendImpl = () => { + throw new Error("send failed"); + }; + const requestId = "perm-send-fail"; + + await expect( + requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ), + ).rejects.toThrow("send failed"); + expect(runtime.pendingApprovalResolvers.size).toBe(0); + }); +}); diff --git a/src/websocket/listen-client.ts b/src/websocket/listen-client.ts index c5f7326..bce93d4 100644 --- a/src/websocket/listen-client.ts +++ b/src/websocket/listen-client.ts @@ -27,6 +27,7 @@ import { permissionMode } from "../permissions/mode"; import { settingsManager } from "../settings-manager"; import { isInteractiveApprovalTool } from "../tools/interactivePolicy"; import { loadTools } from "../tools/manager"; +import type { ControlRequest, ControlResponseBody } from "../types/protocol"; interface StartListenerOptions { connectionId: string; @@ -79,6 +80,11 @@ interface ModeChangeMessage { mode: "default" | "acceptEdits" | "plan" | "bypassPermissions"; } +interface WsControlResponse { + type: "control_response"; + response: ControlResponseBody; +} + interface ModeChangedMessage { type: "mode_changed"; mode: "default" | "acceptEdits" | "plan" | "bypassPermissions"; @@ -86,13 +92,22 @@ interface ModeChangedMessage { error?: string; } -type ServerMessage = PongMessage | IncomingMessage | ModeChangeMessage; +type ServerMessage = + | PongMessage + | IncomingMessage + | ModeChangeMessage + | WsControlResponse; type ClientMessage = | PingMessage | ResultMessage | RunStartedMessage | ModeChangedMessage; +type PendingApprovalResolver = { + resolve: (response: ControlResponseBody) => void; + reject: (reason: Error) => void; +}; + type ListenerRuntime = { socket: WebSocket | null; heartbeatInterval: NodeJS.Timeout | null; @@ -100,6 +115,7 @@ type ListenerRuntime = { intentionallyClosed: boolean; hasSuccessfulConnection: boolean; messageQueue: Promise; + pendingApprovalResolvers: Map; }; type ApprovalSlot = @@ -159,6 +175,7 @@ function createRuntime(): ListenerRuntime { intentionallyClosed: false, hasSuccessfulConnection: false, messageQueue: Promise.resolve(), + pendingApprovalResolvers: new Map(), }; } @@ -179,6 +196,7 @@ function stopRuntime( ): void { runtime.intentionallyClosed = true; clearRuntimeTimers(runtime); + rejectPendingApprovalResolvers(runtime, "Listener runtime stopped"); if (!runtime.socket) { return; @@ -200,10 +218,29 @@ function stopRuntime( } } -function parseServerMessage(data: WebSocket.RawData): ServerMessage | null { +function isValidControlResponseBody( + value: unknown, +): value is ControlResponseBody { + if (!value || typeof value !== "object") { + return false; + } + + const maybeResponse = value as { + subtype?: unknown; + request_id?: unknown; + }; + return ( + typeof maybeResponse.subtype === "string" && + typeof maybeResponse.request_id === "string" + ); +} + +export function parseServerMessage( + data: WebSocket.RawData, +): ServerMessage | null { try { const raw = typeof data === "string" ? data : data.toString(); - const parsed = JSON.parse(raw) as { type?: string }; + const parsed = JSON.parse(raw) as { type?: string; response?: unknown }; if ( parsed.type === "pong" || parsed.type === "message" || @@ -211,6 +248,12 @@ function parseServerMessage(data: WebSocket.RawData): ServerMessage | null { ) { return parsed as ServerMessage; } + if ( + parsed.type === "control_response" && + isValidControlResponseBody(parsed.response) + ) { + return parsed as ServerMessage; + } return null; } catch { return null; @@ -223,6 +266,65 @@ function sendClientMessage(socket: WebSocket, payload: ClientMessage): void { } } +function sendControlMessageOverWebSocket( + socket: WebSocket, + payload: ControlRequest, +): void { + // Central hook for protocol-only outbound WS messages so future + // filtering/mutation can be added without touching approval flow. + socket.send(JSON.stringify(payload)); +} + +export function resolvePendingApprovalResolver( + runtime: ListenerRuntime, + response: ControlResponseBody, +): boolean { + const requestId = response.request_id; + if (typeof requestId !== "string" || requestId.length === 0) { + return false; + } + + const pending = runtime.pendingApprovalResolvers.get(requestId); + if (!pending) { + return false; + } + + runtime.pendingApprovalResolvers.delete(requestId); + pending.resolve(response); + return true; +} + +export function rejectPendingApprovalResolvers( + runtime: ListenerRuntime, + reason: string, +): void { + for (const [, pending] of runtime.pendingApprovalResolvers) { + pending.reject(new Error(reason)); + } + runtime.pendingApprovalResolvers.clear(); +} + +export function requestApprovalOverWS( + runtime: ListenerRuntime, + socket: WebSocket, + requestId: string, + controlRequest: ControlRequest, +): Promise { + if (socket.readyState !== WebSocket.OPEN) { + return Promise.reject(new Error("WebSocket not open")); + } + + return new Promise((resolve, reject) => { + runtime.pendingApprovalResolvers.set(requestId, { resolve, reject }); + try { + sendControlMessageOverWebSocket(socket, controlRequest); + } catch (error) { + runtime.pendingApprovalResolvers.delete(requestId); + reject(error instanceof Error ? error : new Error(String(error))); + } + }); +} + function buildApprovalExecutionPlan( approvalMessage: ApprovalCreate, pendingApprovals: Array<{ @@ -415,6 +517,14 @@ async function connectWithRetry( return; } + if (parsed.type === "control_response") { + if (runtime !== activeRuntime || runtime.intentionallyClosed) { + return; + } + resolvePendingApprovalResolver(runtime, parsed.response); + return; + } + // Handle mode change messages immediately (not queued) if (parsed.type === "mode_change") { handleModeChange(parsed, socket); @@ -460,6 +570,7 @@ async function connectWithRetry( clearRuntimeTimers(runtime); runtime.socket = null; + rejectPendingApprovalResolvers(runtime, "WebSocket disconnected"); if (runtime.intentionallyClosed) { opts.onDisconnected(); @@ -758,3 +869,8 @@ export function stopListenerClient(): void { activeRuntime = null; stopRuntime(runtime, true); } + +export const __listenClientTestUtils = { + createRuntime, + stopRuntime, +};