From 33db9641e760e8749300f835d6449905e8a958a2 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Tue, 10 Feb 2026 19:15:05 -0800 Subject: [PATCH] feat: background pump (#34) Co-authored-by: Jason Carreira <4029756+jasoncarreira@users.noreply.github.com> --- src/session.test.ts | 255 +++++++++++++++++++++++++++++++++++++++++++- src/session.ts | 191 +++++++++++++++++++++++++-------- 2 files changed, 398 insertions(+), 48 deletions(-) diff --git a/src/session.test.ts b/src/session.test.ts index ed366a0..f603d4a 100644 --- a/src/session.test.ts +++ b/src/session.test.ts @@ -1,5 +1,151 @@ import { describe, expect, test } from "bun:test"; import { Session } from "./session.js"; +import type { SDKMessage, WireMessage } from "./types.js"; + +const BUFFER_LIMIT = 100; + +class MockTransport { + writes: unknown[] = []; + private queue: WireMessage[] = []; + private resolvers: Array<(msg: WireMessage | null) => void> = []; + private closed = false; + + async connect(): Promise { + return; + } + + async write(msg: unknown): Promise { + this.writes.push(msg); + } + + async *messages(): AsyncGenerator { + while (true) { + const msg = await this.read(); + if (msg === null) { + return; + } + yield msg; + } + } + + push(msg: WireMessage): void { + if (this.closed) { + return; + } + if (this.resolvers.length > 0) { + const resolve = this.resolvers.shift()!; + resolve(msg); + return; + } + this.queue.push(msg); + } + + close(): void { + this.end(); + } + + end(): void { + if (this.closed) { + return; + } + this.closed = true; + for (const resolve of this.resolvers) { + resolve(null); + } + this.resolvers = []; + } + + private async read(): Promise { + if (this.queue.length > 0) { + return this.queue.shift()!; + } + if (this.closed) { + return null; + } + return new Promise((resolve) => { + this.resolvers.push(resolve); + }); + } +} + +function attachMockTransport(session: Session, transport: MockTransport): void { + (session as unknown as { transport: MockTransport }).transport = transport; +} + +function createInitMessage(): WireMessage { + return { + type: "system", + subtype: "init", + agent_id: "agent-1", + session_id: "session-1", + conversation_id: "conversation-1", + model: "claude-sonnet-4", + tools: ["Bash"], + } as WireMessage; +} + +function createAssistantMessage(index: number): WireMessage { + return { + type: "message", + message_type: "assistant_message", + uuid: `assistant-${index}`, + content: `msg-${index}`, + } as WireMessage; +} + +function createResultMessage(): WireMessage { + return { + type: "result", + subtype: "success", + result: "done", + duration_ms: 1, + conversation_id: "conversation-1", + stop_reason: "end_turn", + } as WireMessage; +} + +function createCanUseToolRequest( + requestId: string, + toolName: string, + input: Record, +): WireMessage { + return { + type: "control_request", + request_id: requestId, + request: { + subtype: "can_use_tool", + tool_name: toolName, + tool_call_id: `${requestId}-tool-call`, + input, + permission_suggestions: [], + blocked_path: null, + }, + } as WireMessage; +} + +function findControlResponseByRequestId( + writes: unknown[], + requestId: string, +): Record | undefined { + return writes.find((msg) => { + const payload = msg as { type?: string; response?: { request_id?: string } }; + return payload.type === "control_response" && payload.response?.request_id === requestId; + }) as Record | undefined; +} + +async function waitFor( + predicate: () => boolean, + timeoutMs = 1000, +): Promise { + const deadline = Date.now() + timeoutMs; + while (Date.now() < deadline) { + if (predicate()) { + return; + } + await new Promise((resolve) => setTimeout(resolve, 5)); + } + throw new Error(`Timed out after ${timeoutMs}ms`); +} describe("Session", () => { describe("handleCanUseTool with bypassPermissions", () => { @@ -134,7 +280,7 @@ describe("Session", () => { test("uses canUseTool callback when provided and not bypassPermissions", async () => { const session = new Session({ permissionMode: "default", - canUseTool: async (toolName, input) => { + canUseTool: async (toolName) => { if (toolName === "Bash") { return { behavior: "allow" }; } @@ -159,4 +305,111 @@ describe("Session", () => { }); }); }); + + describe("background pump parity", () => { + test("handles can_use_tool control requests before stream iteration starts", async () => { + let callbackInvocations = 0; + const session = new Session({ + permissionMode: "default", + canUseTool: () => { + callbackInvocations += 1; + return { behavior: "allow" }; + }, + }); + const transport = new MockTransport(); + attachMockTransport(session, transport); + + try { + transport.push(createInitMessage()); + await session.initialize(); + + transport.push( + createCanUseToolRequest("pre-stream-approval", "Bash", { + command: "pwd", + }), + ); + + await waitFor(() => + findControlResponseByRequestId( + transport.writes, + "pre-stream-approval", + ) !== undefined, + ); + + expect(callbackInvocations).toBe(1); + expect( + findControlResponseByRequestId( + transport.writes, + "pre-stream-approval", + ), + ).toMatchObject({ + type: "control_response", + response: { + subtype: "success", + request_id: "pre-stream-approval", + response: { + behavior: "allow", + }, + }, + }); + } finally { + session.close(); + } + }); + + test("bounds buffered stream messages and drops oldest deterministically", async () => { + const session = new Session({ + permissionMode: "default", + }); + const transport = new MockTransport(); + attachMockTransport(session, transport); + + const assistantCount = BUFFER_LIMIT + 20; + + try { + transport.push(createInitMessage()); + await session.initialize(); + + for (let i = 1; i <= assistantCount; i++) { + transport.push(createAssistantMessage(i)); + } + transport.push(createResultMessage()); + transport.push( + createCanUseToolRequest("post-result-marker", "EnterPlanMode", {}), + ); + + await waitFor(() => + findControlResponseByRequestId( + transport.writes, + "post-result-marker", + ) !== undefined, + ); + + const streamed: SDKMessage[] = []; + for await (const msg of session.stream()) { + streamed.push(msg); + } + + const assistants = streamed.filter( + (msg): msg is Extract => + msg.type === "assistant", + ); + + const expectedAssistantCount = BUFFER_LIMIT - 1; + const expectedFirstAssistantIndex = + assistantCount - expectedAssistantCount + 1; + + expect(assistants.length).toBe(expectedAssistantCount); + expect(assistants[0]?.content).toBe( + `msg-${expectedFirstAssistantIndex}`, + ); + expect(assistants[assistants.length - 1]?.content).toBe( + `msg-${assistantCount}`, + ); + expect(streamed[streamed.length - 1]?.type).toBe("result"); + } finally { + session.close(); + } + }); + }); }); diff --git a/src/session.ts b/src/session.ts index d066d76..8c7ef47 100644 --- a/src/session.ts +++ b/src/session.ts @@ -33,6 +33,8 @@ function sessionLog(tag: string, ...args: unknown[]) { if (process.env.DEBUG_SDK) console.error(`[SDK-Session] [${tag}]`, ...args); } +const MAX_BUFFERED_STREAM_MESSAGES = 100; + export class Session implements AsyncDisposable { private transport: SubprocessTransport; private _agentId: string | null = null; @@ -40,7 +42,11 @@ export class Session implements AsyncDisposable { private _conversationId: string | null = null; private initialized = false; private externalTools: Map = new Map(); - + private streamQueue: SDKMessage[] = []; + private streamResolvers: Array<(msg: SDKMessage | null) => void> = []; + private pumpPromise: Promise | null = null; + private pumpClosed = false; + private droppedStreamMessages = 0; constructor( private options: InternalSessionOptions = {} @@ -79,6 +85,16 @@ export class Session implements AsyncDisposable { sessionLog("init", "waiting for init message from CLI..."); for await (const msg of this.transport.messages()) { sessionLog("init", `received wire message: type=${msg.type}`); + + if (msg.type === "control_request") { + const handled = await this.handleControlRequest(msg as ControlRequest); + if (!handled) { + const wireMsgAny = msg as unknown as Record; + sessionLog("init", `DROPPED unsupported control_request: subtype=${(wireMsgAny.request as Record)?.subtype || "N/A"}`); + } + continue; + } + if (msg.type === "system" && "subtype" in msg && msg.subtype === "init") { const initMsg = msg as WireMessage & { agent_id: string; @@ -91,6 +107,7 @@ export class Session implements AsyncDisposable { this._sessionId = initMsg.session_id; this._conversationId = initMsg.conversation_id; this.initialized = true; + this.startBackgroundPump(); // Register external tools with CLI if (this.externalTools.size > 0) { @@ -160,66 +177,144 @@ export class Session implements AsyncDisposable { async *stream(): AsyncGenerator { const streamStart = Date.now(); let yieldCount = 0; - let dropCount = 0; let gotResult = false; + + this.startBackgroundPump(); sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId})`); - for await (const wireMsg of this.transport.messages()) { - // Handle CLI → SDK control requests (e.g., can_use_tool, execute_external_tool) - if (wireMsg.type === "control_request") { - const controlReq = wireMsg as ControlRequest; - // Widen to string to allow SDK-extension subtypes not in the protocol union - const subtype: string = controlReq.request.subtype; - sessionLog("stream", `control_request: subtype=${subtype} tool=${(controlReq.request as CanUseToolControlRequest).tool_name || "N/A"}`); - - if (subtype === "can_use_tool") { - await this.handleCanUseTool( - controlReq.request_id, - controlReq.request as CanUseToolControlRequest - ); - continue; - } - if (subtype === "execute_external_tool") { - // SDK extension: not in protocol ControlRequestBody union, extract fields via Record - const rawReq = controlReq.request as Record; - await this.handleExecuteExternalTool( - controlReq.request_id, - { - subtype: "execute_external_tool", - tool_call_id: rawReq.tool_call_id as string, - tool_name: rawReq.tool_name as string, - input: rawReq.input as Record, - } - ); - continue; - } + while (true) { + const sdkMsg = await this.nextBufferedMessage(); + if (!sdkMsg) { + break; } - const sdkMsg = this.transformMessage(wireMsg); - if (sdkMsg) { - yieldCount++; - sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`); - yield sdkMsg; + yieldCount++; + sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`); + yield sdkMsg; - // Stop on result message - if (sdkMsg.type === "result") { - gotResult = true; - break; - } - } else { - dropCount++; - const wireMsgAny = wireMsg as unknown as Record; - sessionLog("stream", `DROPPED wire message #${dropCount}: type=${wireMsg.type} message_type=${wireMsgAny.message_type || "N/A"} subtype=${wireMsgAny.subtype || "N/A"}`); + // Stop on result message + if (sdkMsg.type === "result") { + gotResult = true; + break; } } const elapsed = Date.now() - streamStart; - sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${dropCount} gotResult=${gotResult}`); + sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`); if (!gotResult) { - sessionLog("stream", `WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly`); + sessionLog("stream", "WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly"); } } + private startBackgroundPump(): void { + if (this.pumpPromise) { + return; + } + + this.pumpClosed = false; + this.pumpPromise = this.runBackgroundPump() + .catch((err) => { + sessionLog("pump", `ERROR: ${err instanceof Error ? err.message : String(err)}`); + }) + .finally(() => { + this.pumpClosed = true; + this.resolveAllStreamWaiters(null); + }); + } + + private async runBackgroundPump(): Promise { + sessionLog("pump", "background pump started"); + + for await (const wireMsg of this.transport.messages()) { + if (wireMsg.type === "control_request") { + const handled = await this.handleControlRequest(wireMsg as ControlRequest); + if (!handled) { + const wireMsgAny = wireMsg as unknown as Record; + sessionLog("pump", `DROPPED unsupported control_request: subtype=${(wireMsgAny.request as Record)?.subtype || "N/A"}`); + } + continue; + } + + const sdkMsg = this.transformMessage(wireMsg); + if (sdkMsg) { + this.enqueueStreamMessage(sdkMsg); + } else { + const wireMsgAny = wireMsg as unknown as Record; + sessionLog("pump", `DROPPED wire message: type=${wireMsg.type} message_type=${wireMsgAny.message_type || "N/A"} subtype=${wireMsgAny.subtype || "N/A"}`); + } + } + + sessionLog("pump", "background pump ended"); + } + + private async handleControlRequest(controlReq: ControlRequest): Promise { + // Widen to string to allow SDK-extension subtypes not in the protocol union + const subtype: string = controlReq.request.subtype; + sessionLog("pump", `control_request: subtype=${subtype} tool=${(controlReq.request as CanUseToolControlRequest).tool_name || "N/A"}`); + + if (subtype === "can_use_tool") { + await this.handleCanUseTool( + controlReq.request_id, + controlReq.request as CanUseToolControlRequest + ); + return true; + } + + if (subtype === "execute_external_tool") { + // SDK extension: not in protocol ControlRequestBody union, extract fields via Record + const rawReq = controlReq.request as Record; + await this.handleExecuteExternalTool( + controlReq.request_id, + { + subtype: "execute_external_tool", + tool_call_id: rawReq.tool_call_id as string, + tool_name: rawReq.tool_name as string, + input: rawReq.input as Record, + } + ); + return true; + } + + return false; + } + + private enqueueStreamMessage(msg: SDKMessage): void { + if (this.streamResolvers.length > 0) { + const resolve = this.streamResolvers.shift()!; + resolve(msg); + return; + } + + if (this.streamQueue.length >= MAX_BUFFERED_STREAM_MESSAGES) { + this.streamQueue.shift(); + this.droppedStreamMessages++; + sessionLog("pump", `stream queue overflow: dropped oldest message (total_dropped=${this.droppedStreamMessages}, max=${MAX_BUFFERED_STREAM_MESSAGES})`); + } + + this.streamQueue.push(msg); + } + + private async nextBufferedMessage(): Promise { + if (this.streamQueue.length > 0) { + return this.streamQueue.shift()!; + } + + if (this.pumpClosed) { + return null; + } + + return new Promise((resolve) => { + this.streamResolvers.push(resolve); + }); + } + + private resolveAllStreamWaiters(msg: SDKMessage | null): void { + for (const resolve of this.streamResolvers) { + resolve(msg); + } + this.streamResolvers = []; + } + /** * Register external tools with the CLI */ @@ -430,6 +525,8 @@ export class Session implements AsyncDisposable { close(): void { sessionLog("close", `closing session (agent=${this._agentId}, conversation=${this._conversationId})`); this.transport.close(); + this.pumpClosed = true; + this.resolveAllStreamWaiters(null); } /**