feat: background pump (#34)

Co-authored-by: Jason Carreira <4029756+jasoncarreira@users.noreply.github.com>
This commit is contained in:
Charles Packer
2026-02-10 19:15:05 -08:00
committed by GitHub
parent 6cbe132f8c
commit 33db9641e7
2 changed files with 398 additions and 48 deletions

View File

@@ -1,5 +1,151 @@
import { describe, expect, test } from "bun:test";
import { Session } from "./session.js";
import type { SDKMessage, WireMessage } from "./types.js";
const BUFFER_LIMIT = 100;
class MockTransport {
writes: unknown[] = [];
private queue: WireMessage[] = [];
private resolvers: Array<(msg: WireMessage | null) => void> = [];
private closed = false;
async connect(): Promise<void> {
return;
}
async write(msg: unknown): Promise<void> {
this.writes.push(msg);
}
async *messages(): AsyncGenerator<WireMessage> {
while (true) {
const msg = await this.read();
if (msg === null) {
return;
}
yield msg;
}
}
push(msg: WireMessage): void {
if (this.closed) {
return;
}
if (this.resolvers.length > 0) {
const resolve = this.resolvers.shift()!;
resolve(msg);
return;
}
this.queue.push(msg);
}
close(): void {
this.end();
}
end(): void {
if (this.closed) {
return;
}
this.closed = true;
for (const resolve of this.resolvers) {
resolve(null);
}
this.resolvers = [];
}
private async read(): Promise<WireMessage | null> {
if (this.queue.length > 0) {
return this.queue.shift()!;
}
if (this.closed) {
return null;
}
return new Promise((resolve) => {
this.resolvers.push(resolve);
});
}
}
function attachMockTransport(session: Session, transport: MockTransport): void {
(session as unknown as { transport: MockTransport }).transport = transport;
}
function createInitMessage(): WireMessage {
return {
type: "system",
subtype: "init",
agent_id: "agent-1",
session_id: "session-1",
conversation_id: "conversation-1",
model: "claude-sonnet-4",
tools: ["Bash"],
} as WireMessage;
}
function createAssistantMessage(index: number): WireMessage {
return {
type: "message",
message_type: "assistant_message",
uuid: `assistant-${index}`,
content: `msg-${index}`,
} as WireMessage;
}
function createResultMessage(): WireMessage {
return {
type: "result",
subtype: "success",
result: "done",
duration_ms: 1,
conversation_id: "conversation-1",
stop_reason: "end_turn",
} as WireMessage;
}
function createCanUseToolRequest(
requestId: string,
toolName: string,
input: Record<string, unknown>,
): WireMessage {
return {
type: "control_request",
request_id: requestId,
request: {
subtype: "can_use_tool",
tool_name: toolName,
tool_call_id: `${requestId}-tool-call`,
input,
permission_suggestions: [],
blocked_path: null,
},
} as WireMessage;
}
function findControlResponseByRequestId(
writes: unknown[],
requestId: string,
): Record<string, unknown> | undefined {
return writes.find((msg) => {
const payload = msg as { type?: string; response?: { request_id?: string } };
return payload.type === "control_response" && payload.response?.request_id === requestId;
}) as Record<string, unknown> | undefined;
}
async function waitFor(
predicate: () => boolean,
timeoutMs = 1000,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (predicate()) {
return;
}
await new Promise((resolve) => setTimeout(resolve, 5));
}
throw new Error(`Timed out after ${timeoutMs}ms`);
}
describe("Session", () => {
describe("handleCanUseTool with bypassPermissions", () => {
@@ -134,7 +280,7 @@ describe("Session", () => {
test("uses canUseTool callback when provided and not bypassPermissions", async () => {
const session = new Session({
permissionMode: "default",
canUseTool: async (toolName, input) => {
canUseTool: async (toolName) => {
if (toolName === "Bash") {
return { behavior: "allow" };
}
@@ -159,4 +305,111 @@ describe("Session", () => {
});
});
});
describe("background pump parity", () => {
test("handles can_use_tool control requests before stream iteration starts", async () => {
let callbackInvocations = 0;
const session = new Session({
permissionMode: "default",
canUseTool: () => {
callbackInvocations += 1;
return { behavior: "allow" };
},
});
const transport = new MockTransport();
attachMockTransport(session, transport);
try {
transport.push(createInitMessage());
await session.initialize();
transport.push(
createCanUseToolRequest("pre-stream-approval", "Bash", {
command: "pwd",
}),
);
await waitFor(() =>
findControlResponseByRequestId(
transport.writes,
"pre-stream-approval",
) !== undefined,
);
expect(callbackInvocations).toBe(1);
expect(
findControlResponseByRequestId(
transport.writes,
"pre-stream-approval",
),
).toMatchObject({
type: "control_response",
response: {
subtype: "success",
request_id: "pre-stream-approval",
response: {
behavior: "allow",
},
},
});
} finally {
session.close();
}
});
test("bounds buffered stream messages and drops oldest deterministically", async () => {
const session = new Session({
permissionMode: "default",
});
const transport = new MockTransport();
attachMockTransport(session, transport);
const assistantCount = BUFFER_LIMIT + 20;
try {
transport.push(createInitMessage());
await session.initialize();
for (let i = 1; i <= assistantCount; i++) {
transport.push(createAssistantMessage(i));
}
transport.push(createResultMessage());
transport.push(
createCanUseToolRequest("post-result-marker", "EnterPlanMode", {}),
);
await waitFor(() =>
findControlResponseByRequestId(
transport.writes,
"post-result-marker",
) !== undefined,
);
const streamed: SDKMessage[] = [];
for await (const msg of session.stream()) {
streamed.push(msg);
}
const assistants = streamed.filter(
(msg): msg is Extract<SDKMessage, { type: "assistant" }> =>
msg.type === "assistant",
);
const expectedAssistantCount = BUFFER_LIMIT - 1;
const expectedFirstAssistantIndex =
assistantCount - expectedAssistantCount + 1;
expect(assistants.length).toBe(expectedAssistantCount);
expect(assistants[0]?.content).toBe(
`msg-${expectedFirstAssistantIndex}`,
);
expect(assistants[assistants.length - 1]?.content).toBe(
`msg-${assistantCount}`,
);
expect(streamed[streamed.length - 1]?.type).toBe("result");
} finally {
session.close();
}
});
});
});

View File

@@ -33,6 +33,8 @@ function sessionLog(tag: string, ...args: unknown[]) {
if (process.env.DEBUG_SDK) console.error(`[SDK-Session] [${tag}]`, ...args);
}
const MAX_BUFFERED_STREAM_MESSAGES = 100;
export class Session implements AsyncDisposable {
private transport: SubprocessTransport;
private _agentId: string | null = null;
@@ -40,7 +42,11 @@ export class Session implements AsyncDisposable {
private _conversationId: string | null = null;
private initialized = false;
private externalTools: Map<string, AnyAgentTool> = new Map();
private streamQueue: SDKMessage[] = [];
private streamResolvers: Array<(msg: SDKMessage | null) => void> = [];
private pumpPromise: Promise<void> | null = null;
private pumpClosed = false;
private droppedStreamMessages = 0;
constructor(
private options: InternalSessionOptions = {}
@@ -79,6 +85,16 @@ export class Session implements AsyncDisposable {
sessionLog("init", "waiting for init message from CLI...");
for await (const msg of this.transport.messages()) {
sessionLog("init", `received wire message: type=${msg.type}`);
if (msg.type === "control_request") {
const handled = await this.handleControlRequest(msg as ControlRequest);
if (!handled) {
const wireMsgAny = msg as unknown as Record<string, unknown>;
sessionLog("init", `DROPPED unsupported control_request: subtype=${(wireMsgAny.request as Record<string, unknown>)?.subtype || "N/A"}`);
}
continue;
}
if (msg.type === "system" && "subtype" in msg && msg.subtype === "init") {
const initMsg = msg as WireMessage & {
agent_id: string;
@@ -91,6 +107,7 @@ export class Session implements AsyncDisposable {
this._sessionId = initMsg.session_id;
this._conversationId = initMsg.conversation_id;
this.initialized = true;
this.startBackgroundPump();
// Register external tools with CLI
if (this.externalTools.size > 0) {
@@ -160,66 +177,144 @@ export class Session implements AsyncDisposable {
async *stream(): AsyncGenerator<SDKMessage> {
const streamStart = Date.now();
let yieldCount = 0;
let dropCount = 0;
let gotResult = false;
this.startBackgroundPump();
sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId})`);
for await (const wireMsg of this.transport.messages()) {
// Handle CLI → SDK control requests (e.g., can_use_tool, execute_external_tool)
if (wireMsg.type === "control_request") {
const controlReq = wireMsg as ControlRequest;
// Widen to string to allow SDK-extension subtypes not in the protocol union
const subtype: string = controlReq.request.subtype;
sessionLog("stream", `control_request: subtype=${subtype} tool=${(controlReq.request as CanUseToolControlRequest).tool_name || "N/A"}`);
if (subtype === "can_use_tool") {
await this.handleCanUseTool(
controlReq.request_id,
controlReq.request as CanUseToolControlRequest
);
continue;
}
if (subtype === "execute_external_tool") {
// SDK extension: not in protocol ControlRequestBody union, extract fields via Record
const rawReq = controlReq.request as Record<string, unknown>;
await this.handleExecuteExternalTool(
controlReq.request_id,
{
subtype: "execute_external_tool",
tool_call_id: rawReq.tool_call_id as string,
tool_name: rawReq.tool_name as string,
input: rawReq.input as Record<string, unknown>,
}
);
continue;
}
while (true) {
const sdkMsg = await this.nextBufferedMessage();
if (!sdkMsg) {
break;
}
const sdkMsg = this.transformMessage(wireMsg);
if (sdkMsg) {
yieldCount++;
sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`);
yield sdkMsg;
yieldCount++;
sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`);
yield sdkMsg;
// Stop on result message
if (sdkMsg.type === "result") {
gotResult = true;
break;
}
} else {
dropCount++;
const wireMsgAny = wireMsg as unknown as Record<string, unknown>;
sessionLog("stream", `DROPPED wire message #${dropCount}: type=${wireMsg.type} message_type=${wireMsgAny.message_type || "N/A"} subtype=${wireMsgAny.subtype || "N/A"}`);
// Stop on result message
if (sdkMsg.type === "result") {
gotResult = true;
break;
}
}
const elapsed = Date.now() - streamStart;
sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${dropCount} gotResult=${gotResult}`);
sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`);
if (!gotResult) {
sessionLog("stream", `WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly`);
sessionLog("stream", "WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly");
}
}
private startBackgroundPump(): void {
if (this.pumpPromise) {
return;
}
this.pumpClosed = false;
this.pumpPromise = this.runBackgroundPump()
.catch((err) => {
sessionLog("pump", `ERROR: ${err instanceof Error ? err.message : String(err)}`);
})
.finally(() => {
this.pumpClosed = true;
this.resolveAllStreamWaiters(null);
});
}
private async runBackgroundPump(): Promise<void> {
sessionLog("pump", "background pump started");
for await (const wireMsg of this.transport.messages()) {
if (wireMsg.type === "control_request") {
const handled = await this.handleControlRequest(wireMsg as ControlRequest);
if (!handled) {
const wireMsgAny = wireMsg as unknown as Record<string, unknown>;
sessionLog("pump", `DROPPED unsupported control_request: subtype=${(wireMsgAny.request as Record<string, unknown>)?.subtype || "N/A"}`);
}
continue;
}
const sdkMsg = this.transformMessage(wireMsg);
if (sdkMsg) {
this.enqueueStreamMessage(sdkMsg);
} else {
const wireMsgAny = wireMsg as unknown as Record<string, unknown>;
sessionLog("pump", `DROPPED wire message: type=${wireMsg.type} message_type=${wireMsgAny.message_type || "N/A"} subtype=${wireMsgAny.subtype || "N/A"}`);
}
}
sessionLog("pump", "background pump ended");
}
private async handleControlRequest(controlReq: ControlRequest): Promise<boolean> {
// Widen to string to allow SDK-extension subtypes not in the protocol union
const subtype: string = controlReq.request.subtype;
sessionLog("pump", `control_request: subtype=${subtype} tool=${(controlReq.request as CanUseToolControlRequest).tool_name || "N/A"}`);
if (subtype === "can_use_tool") {
await this.handleCanUseTool(
controlReq.request_id,
controlReq.request as CanUseToolControlRequest
);
return true;
}
if (subtype === "execute_external_tool") {
// SDK extension: not in protocol ControlRequestBody union, extract fields via Record
const rawReq = controlReq.request as Record<string, unknown>;
await this.handleExecuteExternalTool(
controlReq.request_id,
{
subtype: "execute_external_tool",
tool_call_id: rawReq.tool_call_id as string,
tool_name: rawReq.tool_name as string,
input: rawReq.input as Record<string, unknown>,
}
);
return true;
}
return false;
}
private enqueueStreamMessage(msg: SDKMessage): void {
if (this.streamResolvers.length > 0) {
const resolve = this.streamResolvers.shift()!;
resolve(msg);
return;
}
if (this.streamQueue.length >= MAX_BUFFERED_STREAM_MESSAGES) {
this.streamQueue.shift();
this.droppedStreamMessages++;
sessionLog("pump", `stream queue overflow: dropped oldest message (total_dropped=${this.droppedStreamMessages}, max=${MAX_BUFFERED_STREAM_MESSAGES})`);
}
this.streamQueue.push(msg);
}
private async nextBufferedMessage(): Promise<SDKMessage | null> {
if (this.streamQueue.length > 0) {
return this.streamQueue.shift()!;
}
if (this.pumpClosed) {
return null;
}
return new Promise((resolve) => {
this.streamResolvers.push(resolve);
});
}
private resolveAllStreamWaiters(msg: SDKMessage | null): void {
for (const resolve of this.streamResolvers) {
resolve(msg);
}
this.streamResolvers = [];
}
/**
* Register external tools with the CLI
*/
@@ -430,6 +525,8 @@ export class Session implements AsyncDisposable {
close(): void {
sessionLog("close", `closing session (agent=${this._agentId}, conversation=${this._conversationId})`);
this.transport.close();
this.pumpClosed = true;
this.resolveAllStreamWaiters(null);
}
/**