diff --git a/src/tests/websocket/listen-client-protocol.test.ts b/src/tests/websocket/listen-client-protocol.test.ts index 056e2d5..74a7842 100644 --- a/src/tests/websocket/listen-client-protocol.test.ts +++ b/src/tests/websocket/listen-client-protocol.test.ts @@ -237,3 +237,142 @@ describe("listen-client requestApprovalOverWS", () => { expect(runtime.pendingApprovalResolvers.size).toBe(0); }); }); + +describe("listen-client controlResponseCapable latch", () => { + test("runtime initializes with controlResponseCapable = false", () => { + const runtime = __listenClientTestUtils.createRuntime(); + expect(runtime.controlResponseCapable).toBe(false); + }); + + test("latch stays true after being set once", () => { + const runtime = __listenClientTestUtils.createRuntime(); + expect(runtime.controlResponseCapable).toBe(false); + + runtime.controlResponseCapable = true; + expect(runtime.controlResponseCapable).toBe(true); + + // Simulates second message without the flag — latch should persist + // (actual latching happens in handleIncomingMessage, but the runtime + // field itself should hold the value) + expect(runtime.controlResponseCapable).toBe(true); + }); +}); + +describe("listen-client capability-gated approval flow", () => { + test("control_response with allow + updatedInput rewrites tool args", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + const requestId = "perm-update-test"; + + const pending = requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ); + + // Simulate control_response with updatedInput + resolvePendingApprovalResolver(runtime, { + subtype: "success", + request_id: requestId, + response: { + behavior: "allow", + updatedInput: { file_path: "/updated/path.ts", content: "new content" }, + }, + }); + + const response = await pending; + expect(response.subtype).toBe("success"); + if (response.subtype === "success") { + const canUseToolResponse = response.response as { + behavior: string; + updatedInput?: Record; + }; + expect(canUseToolResponse.behavior).toBe("allow"); + expect(canUseToolResponse.updatedInput).toEqual({ + file_path: "/updated/path.ts", + content: "new content", + }); + } + }); + + test("control_response with deny includes reason", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + const requestId = "perm-deny-test"; + + const pending = requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ); + + resolvePendingApprovalResolver(runtime, { + subtype: "success", + request_id: requestId, + response: { behavior: "deny", message: "User declined" }, + }); + + const response = await pending; + expect(response.subtype).toBe("success"); + if (response.subtype === "success") { + const canUseToolResponse = response.response as { + behavior: string; + message?: string; + }; + expect(canUseToolResponse.behavior).toBe("deny"); + expect(canUseToolResponse.message).toBe("User declined"); + } + }); + + test("error response from WS triggers denial path", async () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + const requestId = "perm-error-test"; + + const pending = requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ); + + resolvePendingApprovalResolver(runtime, { + subtype: "error", + request_id: requestId, + error: "Internal server error", + }); + + const response = await pending; + expect(response.subtype).toBe("error"); + if (response.subtype === "error") { + expect(response.error).toBe("Internal server error"); + } + }); + + test("outbound control_request is sent through sendControlMessageOverWebSocket (not raw socket.send)", () => { + const runtime = __listenClientTestUtils.createRuntime(); + const socket = new MockSocket(WebSocket.OPEN); + const requestId = "perm-adapter-test"; + + // requestApprovalOverWS uses sendControlMessageOverWebSocket internally + // which ultimately calls socket.send — but goes through the adapter stub. + // We verify the message was sent with the correct shape. + void requestApprovalOverWS( + runtime, + socket as unknown as WebSocket, + requestId, + makeControlRequest(requestId), + ).catch(() => {}); + + expect(socket.sentPayloads).toHaveLength(1); + const sent = JSON.parse(socket.sentPayloads[0] as string); + expect(sent.type).toBe("control_request"); + expect(sent.request_id).toBe(requestId); + expect(sent.request.subtype).toBe("can_use_tool"); + + // Cleanup + rejectPendingApprovalResolvers(runtime, "test cleanup"); + }); +}); diff --git a/src/websocket/listen-client.ts b/src/websocket/listen-client.ts index bce93d4..84842a3 100644 --- a/src/websocket/listen-client.ts +++ b/src/websocket/listen-client.ts @@ -23,11 +23,16 @@ import { createBuffers } from "../cli/helpers/accumulator"; import { classifyApprovals } from "../cli/helpers/approvalClassification"; import { generatePlanFilePath } from "../cli/helpers/planName"; import { drainStreamWithResume } from "../cli/helpers/stream"; +import { computeDiffPreviews } from "../helpers/diffPreview"; import { permissionMode } from "../permissions/mode"; import { settingsManager } from "../settings-manager"; import { isInteractiveApprovalTool } from "../tools/interactivePolicy"; import { loadTools } from "../tools/manager"; -import type { ControlRequest, ControlResponseBody } from "../types/protocol"; +import type { + CanUseToolResponse, + ControlRequest, + ControlResponseBody, +} from "../types/protocol"; interface StartListenerOptions { connectionId: string; @@ -62,6 +67,8 @@ interface IncomingMessage { agentId?: string; conversationId?: string; messages: Array; + /** Cloud sets this when it supports can_use_tool / control_response protocol. */ + supportsControlResponse?: boolean; } interface ResultMessage { @@ -116,6 +123,8 @@ type ListenerRuntime = { hasSuccessfulConnection: boolean; messageQueue: Promise; pendingApprovalResolvers: Map; + /** Latched once supportsControlResponse is seen on any message. */ + controlResponseCapable: boolean; }; type ApprovalSlot = @@ -176,6 +185,7 @@ function createRuntime(): ListenerRuntime { hasSuccessfulConnection: false, messageQueue: Promise.resolve(), pendingApprovalResolvers: new Map(), + controlResponseCapable: false, }; } @@ -543,6 +553,7 @@ async function connectWithRetry( await handleIncomingMessage( parsed, socket, + runtime, opts.onStatusChange, opts.connectionId, ); @@ -605,6 +616,7 @@ async function connectWithRetry( async function handleIncomingMessage( msg: IncomingMessage, socket: WebSocket, + runtime: ListenerRuntime, onStatusChange?: ( status: "idle" | "receiving" | "processing", connectionId: string, @@ -612,6 +624,11 @@ async function handleIncomingMessage( connectionId?: string, ): Promise { try { + // Latch capability: once seen, always use blocking path (strict check to avoid truthy strings) + if (msg.supportsControlResponse === true) { + runtime.controlResponseCapable = true; + } + const agentId = msg.agentId; // requestedConversationId can be: // - undefined: no conversation (use agent endpoint) @@ -646,6 +663,12 @@ async function handleIncomingMessage( "approvals" in firstMessage; if (isApprovalMessage) { + if (runtime.controlResponseCapable && process.env.DEBUG) { + console.warn( + "[Listen] Protocol violation: controlResponseCapable is latched but received legacy ApprovalCreate message. " + + "The cloud should send control_response instead. This may cause the current turn to stall.", + ); + } const approvalMessage = firstMessage as ApprovalCreate; const client = await getClient(); const agent = await client.agents.retrieve(agentId); @@ -694,7 +717,7 @@ async function handleIncomingMessage( ]; } - const stream = await sendMessageStream(conversationId, messagesToSend, { + let stream = await sendMessageStream(conversationId, messagesToSend, { agentId, streamTokens: true, background: true, @@ -769,21 +792,7 @@ async function handleIncomingMessage( requireArgsForAutoApprove: true, }); - // If there are approvals that need user input, pause execution - // Cloud UI will see pending approvals via /v1/runs/:runId/stream from core - // and show approval dialog. When user approves, cloud sends approval message - // back to this device, which resumes execution. - if (needsUserInput.length > 0) { - sendClientMessage(socket, { - type: "result", - success: false, - stopReason: "requires_approval", - }); - break; // Exit loop - cloud will send approval message when user approves - } - - // Only auto-allowed and auto-denied tools remain - // Build decisions list + // Build decisions list (before needsUserInput gate so both paths accumulate here) type Decision = | { type: "approve"; @@ -815,11 +824,84 @@ async function handleIncomingMessage( })), ]; + // Handle tools that need user input + if (needsUserInput.length > 0) { + if (!runtime.controlResponseCapable) { + // Legacy path: break out, let cloud re-enter with ApprovalCreate + sendClientMessage(socket, { + type: "result", + success: false, + stopReason: "requires_approval", + }); + break; + } + + // New path: blocking-in-loop via WS control protocol + for (const ac of needsUserInput) { + const requestId = `perm-${ac.approval.toolCallId}`; + const diffs = await computeDiffPreviews( + ac.approval.toolName, + ac.parsedArgs, + ); + + const controlRequest: ControlRequest = { + type: "control_request", + request_id: requestId, + request: { + subtype: "can_use_tool", + tool_name: ac.approval.toolName, + input: ac.parsedArgs, + tool_call_id: ac.approval.toolCallId, + permission_suggestions: [], + blocked_path: null, + ...(diffs.length > 0 ? { diffs } : {}), + }, + }; + + const responseBody = await requestApprovalOverWS( + runtime, + socket, + requestId, + controlRequest, + ); + + if (responseBody.subtype === "success") { + const response = responseBody.response as + | CanUseToolResponse + | undefined; + if (response?.behavior === "allow") { + const finalApproval = response.updatedInput + ? { + ...ac.approval, + toolArgs: JSON.stringify(response.updatedInput), + } + : ac.approval; + decisions.push({ type: "approve", approval: finalApproval }); + } else { + decisions.push({ + type: "deny", + approval: ac.approval, + reason: response?.message || "Denied via WebSocket", + }); + } + } else { + decisions.push({ + type: "deny", + approval: ac.approval, + reason: + responseBody.subtype === "error" + ? responseBody.error + : "Unknown error", + }); + } + } + } + // Execute approved/denied tools const executionResults = await executeApprovalBatch(decisions); - // Send approval message back to agent to continue execution - const approvalStream = await sendMessageStream( + // Create fresh approval stream for next iteration + stream = await sendMessageStream( conversationId, [ { @@ -833,9 +915,6 @@ async function handleIncomingMessage( background: true, }, ); - - // Replace stream with approval stream for next iteration - Object.assign(stream, approvalStream); } } catch (error) { sendClientMessage(socket, {