feat(listen): add ws control_response resolver plumbing (#1152)

This commit is contained in:
Charles Packer
2026-02-25 19:05:43 -08:00
committed by GitHub
parent 3b2b9ca776
commit ce5ba0496f
2 changed files with 358 additions and 3 deletions

View File

@@ -0,0 +1,239 @@
import { describe, expect, test } from "bun:test";
import WebSocket from "ws";
import type { ControlRequest, ControlResponseBody } from "../../types/protocol";
import {
__listenClientTestUtils,
parseServerMessage,
rejectPendingApprovalResolvers,
requestApprovalOverWS,
resolvePendingApprovalResolver,
} from "../../websocket/listen-client";
class MockSocket {
readyState: number;
closeCalls = 0;
removeAllListenersCalls = 0;
sentPayloads: string[] = [];
sendImpl: (data: string) => void = (data) => {
this.sentPayloads.push(data);
};
constructor(readyState: number = WebSocket.OPEN) {
this.readyState = readyState;
}
send(data: string): void {
this.sendImpl(data);
}
close(): void {
this.closeCalls += 1;
}
removeAllListeners(): this {
this.removeAllListenersCalls += 1;
return this;
}
}
function makeControlRequest(requestId: string): ControlRequest {
return {
type: "control_request",
request_id: requestId,
request: {
subtype: "can_use_tool",
tool_name: "Write",
input: {},
tool_call_id: "call-1",
permission_suggestions: [],
blocked_path: null,
},
};
}
function makeSuccessResponse(requestId: string): ControlResponseBody {
return {
subtype: "success",
request_id: requestId,
response: { behavior: "allow" },
};
}
describe("listen-client parseServerMessage", () => {
test("parses valid control_response with required fields", () => {
const parsed = parseServerMessage(
Buffer.from(
JSON.stringify({
type: "control_response",
response: { subtype: "success", request_id: "perm-1" },
}),
),
);
expect(parsed).not.toBeNull();
expect(parsed?.type).toBe("control_response");
});
test("rejects invalid control_response payloads", () => {
const missingResponse = parseServerMessage(
Buffer.from(JSON.stringify({ type: "control_response" })),
);
expect(missingResponse).toBeNull();
const missingRequestId = parseServerMessage(
Buffer.from(
JSON.stringify({
type: "control_response",
response: { subtype: "success" },
}),
),
);
expect(missingRequestId).toBeNull();
});
test("keeps backward compatibility for message, pong, mode_change", () => {
const msg = parseServerMessage(
Buffer.from(JSON.stringify({ type: "message", messages: [] })),
);
const pong = parseServerMessage(
Buffer.from(JSON.stringify({ type: "pong" })),
);
const modeChange = parseServerMessage(
Buffer.from(JSON.stringify({ type: "mode_change", mode: "default" })),
);
expect(msg?.type).toBe("message");
expect(pong?.type).toBe("pong");
expect(modeChange?.type).toBe("mode_change");
});
});
describe("listen-client approval resolver wiring", () => {
test("resolves matching pending resolver", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const socket = new MockSocket(WebSocket.OPEN);
const requestId = "perm-101";
const pending = requestApprovalOverWS(
runtime,
socket as unknown as WebSocket,
requestId,
makeControlRequest(requestId),
);
expect(runtime.pendingApprovalResolvers.size).toBe(1);
const resolved = resolvePendingApprovalResolver(
runtime,
makeSuccessResponse(requestId),
);
expect(resolved).toBe(true);
await expect(pending).resolves.toMatchObject({
subtype: "success",
request_id: requestId,
});
expect(runtime.pendingApprovalResolvers.size).toBe(0);
});
test("ignores non-matching request_id and keeps pending resolver", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const socket = new MockSocket(WebSocket.OPEN);
const requestId = "perm-201";
const pending = requestApprovalOverWS(
runtime,
socket as unknown as WebSocket,
requestId,
makeControlRequest(requestId),
);
let settled = false;
void pending.then(
() => {
settled = true;
},
() => {
settled = true;
},
);
const resolved = resolvePendingApprovalResolver(
runtime,
makeSuccessResponse("perm-other"),
);
expect(resolved).toBe(false);
await Promise.resolve();
expect(settled).toBe(false);
expect(runtime.pendingApprovalResolvers.size).toBe(1);
const handledPending = pending.catch((error) => error);
rejectPendingApprovalResolvers(runtime, "cleanup");
const cleanupError = await handledPending;
expect(cleanupError).toBeInstanceOf(Error);
expect((cleanupError as Error).message).toBe("cleanup");
});
test("cleanup rejects all pending resolvers", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const first = new Promise<ControlResponseBody>((resolve, reject) => {
runtime.pendingApprovalResolvers.set("perm-a", { resolve, reject });
});
const second = new Promise<ControlResponseBody>((resolve, reject) => {
runtime.pendingApprovalResolvers.set("perm-b", { resolve, reject });
});
rejectPendingApprovalResolvers(runtime, "socket closed");
expect(runtime.pendingApprovalResolvers.size).toBe(0);
await expect(first).rejects.toThrow("socket closed");
await expect(second).rejects.toThrow("socket closed");
});
test("stopRuntime rejects pending resolvers even when callbacks are suppressed", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const pending = new Promise<ControlResponseBody>((resolve, reject) => {
runtime.pendingApprovalResolvers.set("perm-stop", { resolve, reject });
});
const socket = new MockSocket(WebSocket.OPEN);
runtime.socket = socket as unknown as WebSocket;
__listenClientTestUtils.stopRuntime(runtime, true);
expect(runtime.pendingApprovalResolvers.size).toBe(0);
expect(socket.removeAllListenersCalls).toBe(1);
expect(socket.closeCalls).toBe(1);
await expect(pending).rejects.toThrow("Listener runtime stopped");
});
});
describe("listen-client requestApprovalOverWS", () => {
test("rejects immediately when socket is not open", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const socket = new MockSocket(WebSocket.CLOSED);
const requestId = "perm-closed";
await expect(
requestApprovalOverWS(
runtime,
socket as unknown as WebSocket,
requestId,
makeControlRequest(requestId),
),
).rejects.toThrow("WebSocket not open");
expect(runtime.pendingApprovalResolvers.size).toBe(0);
});
test("cleans up resolver when send throws", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const socket = new MockSocket(WebSocket.OPEN);
socket.sendImpl = () => {
throw new Error("send failed");
};
const requestId = "perm-send-fail";
await expect(
requestApprovalOverWS(
runtime,
socket as unknown as WebSocket,
requestId,
makeControlRequest(requestId),
),
).rejects.toThrow("send failed");
expect(runtime.pendingApprovalResolvers.size).toBe(0);
});
});

View File

@@ -27,6 +27,7 @@ import { permissionMode } from "../permissions/mode";
import { settingsManager } from "../settings-manager";
import { isInteractiveApprovalTool } from "../tools/interactivePolicy";
import { loadTools } from "../tools/manager";
import type { ControlRequest, ControlResponseBody } from "../types/protocol";
interface StartListenerOptions {
connectionId: string;
@@ -79,6 +80,11 @@ interface ModeChangeMessage {
mode: "default" | "acceptEdits" | "plan" | "bypassPermissions";
}
interface WsControlResponse {
type: "control_response";
response: ControlResponseBody;
}
interface ModeChangedMessage {
type: "mode_changed";
mode: "default" | "acceptEdits" | "plan" | "bypassPermissions";
@@ -86,13 +92,22 @@ interface ModeChangedMessage {
error?: string;
}
type ServerMessage = PongMessage | IncomingMessage | ModeChangeMessage;
type ServerMessage =
| PongMessage
| IncomingMessage
| ModeChangeMessage
| WsControlResponse;
type ClientMessage =
| PingMessage
| ResultMessage
| RunStartedMessage
| ModeChangedMessage;
type PendingApprovalResolver = {
resolve: (response: ControlResponseBody) => void;
reject: (reason: Error) => void;
};
type ListenerRuntime = {
socket: WebSocket | null;
heartbeatInterval: NodeJS.Timeout | null;
@@ -100,6 +115,7 @@ type ListenerRuntime = {
intentionallyClosed: boolean;
hasSuccessfulConnection: boolean;
messageQueue: Promise<void>;
pendingApprovalResolvers: Map<string, PendingApprovalResolver>;
};
type ApprovalSlot =
@@ -159,6 +175,7 @@ function createRuntime(): ListenerRuntime {
intentionallyClosed: false,
hasSuccessfulConnection: false,
messageQueue: Promise.resolve(),
pendingApprovalResolvers: new Map(),
};
}
@@ -179,6 +196,7 @@ function stopRuntime(
): void {
runtime.intentionallyClosed = true;
clearRuntimeTimers(runtime);
rejectPendingApprovalResolvers(runtime, "Listener runtime stopped");
if (!runtime.socket) {
return;
@@ -200,10 +218,29 @@ function stopRuntime(
}
}
function parseServerMessage(data: WebSocket.RawData): ServerMessage | null {
function isValidControlResponseBody(
value: unknown,
): value is ControlResponseBody {
if (!value || typeof value !== "object") {
return false;
}
const maybeResponse = value as {
subtype?: unknown;
request_id?: unknown;
};
return (
typeof maybeResponse.subtype === "string" &&
typeof maybeResponse.request_id === "string"
);
}
export function parseServerMessage(
data: WebSocket.RawData,
): ServerMessage | null {
try {
const raw = typeof data === "string" ? data : data.toString();
const parsed = JSON.parse(raw) as { type?: string };
const parsed = JSON.parse(raw) as { type?: string; response?: unknown };
if (
parsed.type === "pong" ||
parsed.type === "message" ||
@@ -211,6 +248,12 @@ function parseServerMessage(data: WebSocket.RawData): ServerMessage | null {
) {
return parsed as ServerMessage;
}
if (
parsed.type === "control_response" &&
isValidControlResponseBody(parsed.response)
) {
return parsed as ServerMessage;
}
return null;
} catch {
return null;
@@ -223,6 +266,65 @@ function sendClientMessage(socket: WebSocket, payload: ClientMessage): void {
}
}
function sendControlMessageOverWebSocket(
socket: WebSocket,
payload: ControlRequest,
): void {
// Central hook for protocol-only outbound WS messages so future
// filtering/mutation can be added without touching approval flow.
socket.send(JSON.stringify(payload));
}
export function resolvePendingApprovalResolver(
runtime: ListenerRuntime,
response: ControlResponseBody,
): boolean {
const requestId = response.request_id;
if (typeof requestId !== "string" || requestId.length === 0) {
return false;
}
const pending = runtime.pendingApprovalResolvers.get(requestId);
if (!pending) {
return false;
}
runtime.pendingApprovalResolvers.delete(requestId);
pending.resolve(response);
return true;
}
export function rejectPendingApprovalResolvers(
runtime: ListenerRuntime,
reason: string,
): void {
for (const [, pending] of runtime.pendingApprovalResolvers) {
pending.reject(new Error(reason));
}
runtime.pendingApprovalResolvers.clear();
}
export function requestApprovalOverWS(
runtime: ListenerRuntime,
socket: WebSocket,
requestId: string,
controlRequest: ControlRequest,
): Promise<ControlResponseBody> {
if (socket.readyState !== WebSocket.OPEN) {
return Promise.reject(new Error("WebSocket not open"));
}
return new Promise<ControlResponseBody>((resolve, reject) => {
runtime.pendingApprovalResolvers.set(requestId, { resolve, reject });
try {
sendControlMessageOverWebSocket(socket, controlRequest);
} catch (error) {
runtime.pendingApprovalResolvers.delete(requestId);
reject(error instanceof Error ? error : new Error(String(error)));
}
});
}
function buildApprovalExecutionPlan(
approvalMessage: ApprovalCreate,
pendingApprovals: Array<{
@@ -415,6 +517,14 @@ async function connectWithRetry(
return;
}
if (parsed.type === "control_response") {
if (runtime !== activeRuntime || runtime.intentionallyClosed) {
return;
}
resolvePendingApprovalResolver(runtime, parsed.response);
return;
}
// Handle mode change messages immediately (not queued)
if (parsed.type === "mode_change") {
handleModeChange(parsed, socket);
@@ -460,6 +570,7 @@ async function connectWithRetry(
clearRuntimeTimers(runtime);
runtime.socket = null;
rejectPendingApprovalResolvers(runtime, "WebSocket disconnected");
if (runtime.intentionallyClosed) {
opts.onDisconnected();
@@ -758,3 +869,8 @@ export function stopListenerClient(): void {
activeRuntime = null;
stopRuntime(runtime, true);
}
export const __listenClientTestUtils = {
createRuntime,
stopRuntime,
};