feat(listen): add ws control_response resolver plumbing (#1152)
This commit is contained in:
239
src/tests/websocket/listen-client-protocol.test.ts
Normal file
239
src/tests/websocket/listen-client-protocol.test.ts
Normal file
@@ -0,0 +1,239 @@
|
||||
import { describe, expect, test } from "bun:test";
|
||||
import WebSocket from "ws";
|
||||
import type { ControlRequest, ControlResponseBody } from "../../types/protocol";
|
||||
import {
|
||||
__listenClientTestUtils,
|
||||
parseServerMessage,
|
||||
rejectPendingApprovalResolvers,
|
||||
requestApprovalOverWS,
|
||||
resolvePendingApprovalResolver,
|
||||
} from "../../websocket/listen-client";
|
||||
|
||||
class MockSocket {
|
||||
readyState: number;
|
||||
closeCalls = 0;
|
||||
removeAllListenersCalls = 0;
|
||||
sentPayloads: string[] = [];
|
||||
sendImpl: (data: string) => void = (data) => {
|
||||
this.sentPayloads.push(data);
|
||||
};
|
||||
|
||||
constructor(readyState: number = WebSocket.OPEN) {
|
||||
this.readyState = readyState;
|
||||
}
|
||||
|
||||
send(data: string): void {
|
||||
this.sendImpl(data);
|
||||
}
|
||||
|
||||
close(): void {
|
||||
this.closeCalls += 1;
|
||||
}
|
||||
|
||||
removeAllListeners(): this {
|
||||
this.removeAllListenersCalls += 1;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
function makeControlRequest(requestId: string): ControlRequest {
|
||||
return {
|
||||
type: "control_request",
|
||||
request_id: requestId,
|
||||
request: {
|
||||
subtype: "can_use_tool",
|
||||
tool_name: "Write",
|
||||
input: {},
|
||||
tool_call_id: "call-1",
|
||||
permission_suggestions: [],
|
||||
blocked_path: null,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function makeSuccessResponse(requestId: string): ControlResponseBody {
|
||||
return {
|
||||
subtype: "success",
|
||||
request_id: requestId,
|
||||
response: { behavior: "allow" },
|
||||
};
|
||||
}
|
||||
|
||||
describe("listen-client parseServerMessage", () => {
|
||||
test("parses valid control_response with required fields", () => {
|
||||
const parsed = parseServerMessage(
|
||||
Buffer.from(
|
||||
JSON.stringify({
|
||||
type: "control_response",
|
||||
response: { subtype: "success", request_id: "perm-1" },
|
||||
}),
|
||||
),
|
||||
);
|
||||
expect(parsed).not.toBeNull();
|
||||
expect(parsed?.type).toBe("control_response");
|
||||
});
|
||||
|
||||
test("rejects invalid control_response payloads", () => {
|
||||
const missingResponse = parseServerMessage(
|
||||
Buffer.from(JSON.stringify({ type: "control_response" })),
|
||||
);
|
||||
expect(missingResponse).toBeNull();
|
||||
|
||||
const missingRequestId = parseServerMessage(
|
||||
Buffer.from(
|
||||
JSON.stringify({
|
||||
type: "control_response",
|
||||
response: { subtype: "success" },
|
||||
}),
|
||||
),
|
||||
);
|
||||
expect(missingRequestId).toBeNull();
|
||||
});
|
||||
|
||||
test("keeps backward compatibility for message, pong, mode_change", () => {
|
||||
const msg = parseServerMessage(
|
||||
Buffer.from(JSON.stringify({ type: "message", messages: [] })),
|
||||
);
|
||||
const pong = parseServerMessage(
|
||||
Buffer.from(JSON.stringify({ type: "pong" })),
|
||||
);
|
||||
const modeChange = parseServerMessage(
|
||||
Buffer.from(JSON.stringify({ type: "mode_change", mode: "default" })),
|
||||
);
|
||||
expect(msg?.type).toBe("message");
|
||||
expect(pong?.type).toBe("pong");
|
||||
expect(modeChange?.type).toBe("mode_change");
|
||||
});
|
||||
});
|
||||
|
||||
describe("listen-client approval resolver wiring", () => {
|
||||
test("resolves matching pending resolver", async () => {
|
||||
const runtime = __listenClientTestUtils.createRuntime();
|
||||
const socket = new MockSocket(WebSocket.OPEN);
|
||||
const requestId = "perm-101";
|
||||
|
||||
const pending = requestApprovalOverWS(
|
||||
runtime,
|
||||
socket as unknown as WebSocket,
|
||||
requestId,
|
||||
makeControlRequest(requestId),
|
||||
);
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(1);
|
||||
|
||||
const resolved = resolvePendingApprovalResolver(
|
||||
runtime,
|
||||
makeSuccessResponse(requestId),
|
||||
);
|
||||
expect(resolved).toBe(true);
|
||||
await expect(pending).resolves.toMatchObject({
|
||||
subtype: "success",
|
||||
request_id: requestId,
|
||||
});
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(0);
|
||||
});
|
||||
|
||||
test("ignores non-matching request_id and keeps pending resolver", async () => {
|
||||
const runtime = __listenClientTestUtils.createRuntime();
|
||||
const socket = new MockSocket(WebSocket.OPEN);
|
||||
const requestId = "perm-201";
|
||||
|
||||
const pending = requestApprovalOverWS(
|
||||
runtime,
|
||||
socket as unknown as WebSocket,
|
||||
requestId,
|
||||
makeControlRequest(requestId),
|
||||
);
|
||||
let settled = false;
|
||||
void pending.then(
|
||||
() => {
|
||||
settled = true;
|
||||
},
|
||||
() => {
|
||||
settled = true;
|
||||
},
|
||||
);
|
||||
|
||||
const resolved = resolvePendingApprovalResolver(
|
||||
runtime,
|
||||
makeSuccessResponse("perm-other"),
|
||||
);
|
||||
expect(resolved).toBe(false);
|
||||
await Promise.resolve();
|
||||
expect(settled).toBe(false);
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(1);
|
||||
|
||||
const handledPending = pending.catch((error) => error);
|
||||
rejectPendingApprovalResolvers(runtime, "cleanup");
|
||||
const cleanupError = await handledPending;
|
||||
expect(cleanupError).toBeInstanceOf(Error);
|
||||
expect((cleanupError as Error).message).toBe("cleanup");
|
||||
});
|
||||
|
||||
test("cleanup rejects all pending resolvers", async () => {
|
||||
const runtime = __listenClientTestUtils.createRuntime();
|
||||
const first = new Promise<ControlResponseBody>((resolve, reject) => {
|
||||
runtime.pendingApprovalResolvers.set("perm-a", { resolve, reject });
|
||||
});
|
||||
const second = new Promise<ControlResponseBody>((resolve, reject) => {
|
||||
runtime.pendingApprovalResolvers.set("perm-b", { resolve, reject });
|
||||
});
|
||||
|
||||
rejectPendingApprovalResolvers(runtime, "socket closed");
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(0);
|
||||
await expect(first).rejects.toThrow("socket closed");
|
||||
await expect(second).rejects.toThrow("socket closed");
|
||||
});
|
||||
|
||||
test("stopRuntime rejects pending resolvers even when callbacks are suppressed", async () => {
|
||||
const runtime = __listenClientTestUtils.createRuntime();
|
||||
const pending = new Promise<ControlResponseBody>((resolve, reject) => {
|
||||
runtime.pendingApprovalResolvers.set("perm-stop", { resolve, reject });
|
||||
});
|
||||
const socket = new MockSocket(WebSocket.OPEN);
|
||||
runtime.socket = socket as unknown as WebSocket;
|
||||
|
||||
__listenClientTestUtils.stopRuntime(runtime, true);
|
||||
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(0);
|
||||
expect(socket.removeAllListenersCalls).toBe(1);
|
||||
expect(socket.closeCalls).toBe(1);
|
||||
await expect(pending).rejects.toThrow("Listener runtime stopped");
|
||||
});
|
||||
});
|
||||
|
||||
describe("listen-client requestApprovalOverWS", () => {
|
||||
test("rejects immediately when socket is not open", async () => {
|
||||
const runtime = __listenClientTestUtils.createRuntime();
|
||||
const socket = new MockSocket(WebSocket.CLOSED);
|
||||
const requestId = "perm-closed";
|
||||
|
||||
await expect(
|
||||
requestApprovalOverWS(
|
||||
runtime,
|
||||
socket as unknown as WebSocket,
|
||||
requestId,
|
||||
makeControlRequest(requestId),
|
||||
),
|
||||
).rejects.toThrow("WebSocket not open");
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(0);
|
||||
});
|
||||
|
||||
test("cleans up resolver when send throws", async () => {
|
||||
const runtime = __listenClientTestUtils.createRuntime();
|
||||
const socket = new MockSocket(WebSocket.OPEN);
|
||||
socket.sendImpl = () => {
|
||||
throw new Error("send failed");
|
||||
};
|
||||
const requestId = "perm-send-fail";
|
||||
|
||||
await expect(
|
||||
requestApprovalOverWS(
|
||||
runtime,
|
||||
socket as unknown as WebSocket,
|
||||
requestId,
|
||||
makeControlRequest(requestId),
|
||||
),
|
||||
).rejects.toThrow("send failed");
|
||||
expect(runtime.pendingApprovalResolvers.size).toBe(0);
|
||||
});
|
||||
});
|
||||
@@ -27,6 +27,7 @@ import { permissionMode } from "../permissions/mode";
|
||||
import { settingsManager } from "../settings-manager";
|
||||
import { isInteractiveApprovalTool } from "../tools/interactivePolicy";
|
||||
import { loadTools } from "../tools/manager";
|
||||
import type { ControlRequest, ControlResponseBody } from "../types/protocol";
|
||||
|
||||
interface StartListenerOptions {
|
||||
connectionId: string;
|
||||
@@ -79,6 +80,11 @@ interface ModeChangeMessage {
|
||||
mode: "default" | "acceptEdits" | "plan" | "bypassPermissions";
|
||||
}
|
||||
|
||||
interface WsControlResponse {
|
||||
type: "control_response";
|
||||
response: ControlResponseBody;
|
||||
}
|
||||
|
||||
interface ModeChangedMessage {
|
||||
type: "mode_changed";
|
||||
mode: "default" | "acceptEdits" | "plan" | "bypassPermissions";
|
||||
@@ -86,13 +92,22 @@ interface ModeChangedMessage {
|
||||
error?: string;
|
||||
}
|
||||
|
||||
type ServerMessage = PongMessage | IncomingMessage | ModeChangeMessage;
|
||||
type ServerMessage =
|
||||
| PongMessage
|
||||
| IncomingMessage
|
||||
| ModeChangeMessage
|
||||
| WsControlResponse;
|
||||
type ClientMessage =
|
||||
| PingMessage
|
||||
| ResultMessage
|
||||
| RunStartedMessage
|
||||
| ModeChangedMessage;
|
||||
|
||||
type PendingApprovalResolver = {
|
||||
resolve: (response: ControlResponseBody) => void;
|
||||
reject: (reason: Error) => void;
|
||||
};
|
||||
|
||||
type ListenerRuntime = {
|
||||
socket: WebSocket | null;
|
||||
heartbeatInterval: NodeJS.Timeout | null;
|
||||
@@ -100,6 +115,7 @@ type ListenerRuntime = {
|
||||
intentionallyClosed: boolean;
|
||||
hasSuccessfulConnection: boolean;
|
||||
messageQueue: Promise<void>;
|
||||
pendingApprovalResolvers: Map<string, PendingApprovalResolver>;
|
||||
};
|
||||
|
||||
type ApprovalSlot =
|
||||
@@ -159,6 +175,7 @@ function createRuntime(): ListenerRuntime {
|
||||
intentionallyClosed: false,
|
||||
hasSuccessfulConnection: false,
|
||||
messageQueue: Promise.resolve(),
|
||||
pendingApprovalResolvers: new Map(),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -179,6 +196,7 @@ function stopRuntime(
|
||||
): void {
|
||||
runtime.intentionallyClosed = true;
|
||||
clearRuntimeTimers(runtime);
|
||||
rejectPendingApprovalResolvers(runtime, "Listener runtime stopped");
|
||||
|
||||
if (!runtime.socket) {
|
||||
return;
|
||||
@@ -200,10 +218,29 @@ function stopRuntime(
|
||||
}
|
||||
}
|
||||
|
||||
function parseServerMessage(data: WebSocket.RawData): ServerMessage | null {
|
||||
function isValidControlResponseBody(
|
||||
value: unknown,
|
||||
): value is ControlResponseBody {
|
||||
if (!value || typeof value !== "object") {
|
||||
return false;
|
||||
}
|
||||
|
||||
const maybeResponse = value as {
|
||||
subtype?: unknown;
|
||||
request_id?: unknown;
|
||||
};
|
||||
return (
|
||||
typeof maybeResponse.subtype === "string" &&
|
||||
typeof maybeResponse.request_id === "string"
|
||||
);
|
||||
}
|
||||
|
||||
export function parseServerMessage(
|
||||
data: WebSocket.RawData,
|
||||
): ServerMessage | null {
|
||||
try {
|
||||
const raw = typeof data === "string" ? data : data.toString();
|
||||
const parsed = JSON.parse(raw) as { type?: string };
|
||||
const parsed = JSON.parse(raw) as { type?: string; response?: unknown };
|
||||
if (
|
||||
parsed.type === "pong" ||
|
||||
parsed.type === "message" ||
|
||||
@@ -211,6 +248,12 @@ function parseServerMessage(data: WebSocket.RawData): ServerMessage | null {
|
||||
) {
|
||||
return parsed as ServerMessage;
|
||||
}
|
||||
if (
|
||||
parsed.type === "control_response" &&
|
||||
isValidControlResponseBody(parsed.response)
|
||||
) {
|
||||
return parsed as ServerMessage;
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
@@ -223,6 +266,65 @@ function sendClientMessage(socket: WebSocket, payload: ClientMessage): void {
|
||||
}
|
||||
}
|
||||
|
||||
function sendControlMessageOverWebSocket(
|
||||
socket: WebSocket,
|
||||
payload: ControlRequest,
|
||||
): void {
|
||||
// Central hook for protocol-only outbound WS messages so future
|
||||
// filtering/mutation can be added without touching approval flow.
|
||||
socket.send(JSON.stringify(payload));
|
||||
}
|
||||
|
||||
export function resolvePendingApprovalResolver(
|
||||
runtime: ListenerRuntime,
|
||||
response: ControlResponseBody,
|
||||
): boolean {
|
||||
const requestId = response.request_id;
|
||||
if (typeof requestId !== "string" || requestId.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const pending = runtime.pendingApprovalResolvers.get(requestId);
|
||||
if (!pending) {
|
||||
return false;
|
||||
}
|
||||
|
||||
runtime.pendingApprovalResolvers.delete(requestId);
|
||||
pending.resolve(response);
|
||||
return true;
|
||||
}
|
||||
|
||||
export function rejectPendingApprovalResolvers(
|
||||
runtime: ListenerRuntime,
|
||||
reason: string,
|
||||
): void {
|
||||
for (const [, pending] of runtime.pendingApprovalResolvers) {
|
||||
pending.reject(new Error(reason));
|
||||
}
|
||||
runtime.pendingApprovalResolvers.clear();
|
||||
}
|
||||
|
||||
export function requestApprovalOverWS(
|
||||
runtime: ListenerRuntime,
|
||||
socket: WebSocket,
|
||||
requestId: string,
|
||||
controlRequest: ControlRequest,
|
||||
): Promise<ControlResponseBody> {
|
||||
if (socket.readyState !== WebSocket.OPEN) {
|
||||
return Promise.reject(new Error("WebSocket not open"));
|
||||
}
|
||||
|
||||
return new Promise<ControlResponseBody>((resolve, reject) => {
|
||||
runtime.pendingApprovalResolvers.set(requestId, { resolve, reject });
|
||||
try {
|
||||
sendControlMessageOverWebSocket(socket, controlRequest);
|
||||
} catch (error) {
|
||||
runtime.pendingApprovalResolvers.delete(requestId);
|
||||
reject(error instanceof Error ? error : new Error(String(error)));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function buildApprovalExecutionPlan(
|
||||
approvalMessage: ApprovalCreate,
|
||||
pendingApprovals: Array<{
|
||||
@@ -415,6 +517,14 @@ async function connectWithRetry(
|
||||
return;
|
||||
}
|
||||
|
||||
if (parsed.type === "control_response") {
|
||||
if (runtime !== activeRuntime || runtime.intentionallyClosed) {
|
||||
return;
|
||||
}
|
||||
resolvePendingApprovalResolver(runtime, parsed.response);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle mode change messages immediately (not queued)
|
||||
if (parsed.type === "mode_change") {
|
||||
handleModeChange(parsed, socket);
|
||||
@@ -460,6 +570,7 @@ async function connectWithRetry(
|
||||
|
||||
clearRuntimeTimers(runtime);
|
||||
runtime.socket = null;
|
||||
rejectPendingApprovalResolvers(runtime, "WebSocket disconnected");
|
||||
|
||||
if (runtime.intentionallyClosed) {
|
||||
opts.onDisconnected();
|
||||
@@ -758,3 +869,8 @@ export function stopListenerClient(): void {
|
||||
activeRuntime = null;
|
||||
stopRuntime(runtime, true);
|
||||
}
|
||||
|
||||
export const __listenClientTestUtils = {
|
||||
createRuntime,
|
||||
stopRuntime,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user