feat: background pump (#34)
Co-authored-by: Jason Carreira <4029756+jasoncarreira@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,151 @@
|
||||
import { describe, expect, test } from "bun:test";
|
||||
import { Session } from "./session.js";
|
||||
import type { SDKMessage, WireMessage } from "./types.js";
|
||||
|
||||
const BUFFER_LIMIT = 100;
|
||||
|
||||
class MockTransport {
|
||||
writes: unknown[] = [];
|
||||
private queue: WireMessage[] = [];
|
||||
private resolvers: Array<(msg: WireMessage | null) => void> = [];
|
||||
private closed = false;
|
||||
|
||||
async connect(): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
async write(msg: unknown): Promise<void> {
|
||||
this.writes.push(msg);
|
||||
}
|
||||
|
||||
async *messages(): AsyncGenerator<WireMessage> {
|
||||
while (true) {
|
||||
const msg = await this.read();
|
||||
if (msg === null) {
|
||||
return;
|
||||
}
|
||||
yield msg;
|
||||
}
|
||||
}
|
||||
|
||||
push(msg: WireMessage): void {
|
||||
if (this.closed) {
|
||||
return;
|
||||
}
|
||||
if (this.resolvers.length > 0) {
|
||||
const resolve = this.resolvers.shift()!;
|
||||
resolve(msg);
|
||||
return;
|
||||
}
|
||||
this.queue.push(msg);
|
||||
}
|
||||
|
||||
close(): void {
|
||||
this.end();
|
||||
}
|
||||
|
||||
end(): void {
|
||||
if (this.closed) {
|
||||
return;
|
||||
}
|
||||
this.closed = true;
|
||||
for (const resolve of this.resolvers) {
|
||||
resolve(null);
|
||||
}
|
||||
this.resolvers = [];
|
||||
}
|
||||
|
||||
private async read(): Promise<WireMessage | null> {
|
||||
if (this.queue.length > 0) {
|
||||
return this.queue.shift()!;
|
||||
}
|
||||
if (this.closed) {
|
||||
return null;
|
||||
}
|
||||
return new Promise((resolve) => {
|
||||
this.resolvers.push(resolve);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function attachMockTransport(session: Session, transport: MockTransport): void {
|
||||
(session as unknown as { transport: MockTransport }).transport = transport;
|
||||
}
|
||||
|
||||
function createInitMessage(): WireMessage {
|
||||
return {
|
||||
type: "system",
|
||||
subtype: "init",
|
||||
agent_id: "agent-1",
|
||||
session_id: "session-1",
|
||||
conversation_id: "conversation-1",
|
||||
model: "claude-sonnet-4",
|
||||
tools: ["Bash"],
|
||||
} as WireMessage;
|
||||
}
|
||||
|
||||
function createAssistantMessage(index: number): WireMessage {
|
||||
return {
|
||||
type: "message",
|
||||
message_type: "assistant_message",
|
||||
uuid: `assistant-${index}`,
|
||||
content: `msg-${index}`,
|
||||
} as WireMessage;
|
||||
}
|
||||
|
||||
function createResultMessage(): WireMessage {
|
||||
return {
|
||||
type: "result",
|
||||
subtype: "success",
|
||||
result: "done",
|
||||
duration_ms: 1,
|
||||
conversation_id: "conversation-1",
|
||||
stop_reason: "end_turn",
|
||||
} as WireMessage;
|
||||
}
|
||||
|
||||
function createCanUseToolRequest(
|
||||
requestId: string,
|
||||
toolName: string,
|
||||
input: Record<string, unknown>,
|
||||
): WireMessage {
|
||||
return {
|
||||
type: "control_request",
|
||||
request_id: requestId,
|
||||
request: {
|
||||
subtype: "can_use_tool",
|
||||
tool_name: toolName,
|
||||
tool_call_id: `${requestId}-tool-call`,
|
||||
input,
|
||||
permission_suggestions: [],
|
||||
blocked_path: null,
|
||||
},
|
||||
} as WireMessage;
|
||||
}
|
||||
|
||||
function findControlResponseByRequestId(
|
||||
writes: unknown[],
|
||||
requestId: string,
|
||||
): Record<string, unknown> | undefined {
|
||||
return writes.find((msg) => {
|
||||
const payload = msg as { type?: string; response?: { request_id?: string } };
|
||||
return payload.type === "control_response" && payload.response?.request_id === requestId;
|
||||
}) as Record<string, unknown> | undefined;
|
||||
}
|
||||
|
||||
async function waitFor(
|
||||
predicate: () => boolean,
|
||||
timeoutMs = 1000,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (predicate()) {
|
||||
return;
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, 5));
|
||||
}
|
||||
throw new Error(`Timed out after ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
describe("Session", () => {
|
||||
describe("handleCanUseTool with bypassPermissions", () => {
|
||||
@@ -134,7 +280,7 @@ describe("Session", () => {
|
||||
test("uses canUseTool callback when provided and not bypassPermissions", async () => {
|
||||
const session = new Session({
|
||||
permissionMode: "default",
|
||||
canUseTool: async (toolName, input) => {
|
||||
canUseTool: async (toolName) => {
|
||||
if (toolName === "Bash") {
|
||||
return { behavior: "allow" };
|
||||
}
|
||||
@@ -159,4 +305,111 @@ describe("Session", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("background pump parity", () => {
|
||||
test("handles can_use_tool control requests before stream iteration starts", async () => {
|
||||
let callbackInvocations = 0;
|
||||
const session = new Session({
|
||||
permissionMode: "default",
|
||||
canUseTool: () => {
|
||||
callbackInvocations += 1;
|
||||
return { behavior: "allow" };
|
||||
},
|
||||
});
|
||||
const transport = new MockTransport();
|
||||
attachMockTransport(session, transport);
|
||||
|
||||
try {
|
||||
transport.push(createInitMessage());
|
||||
await session.initialize();
|
||||
|
||||
transport.push(
|
||||
createCanUseToolRequest("pre-stream-approval", "Bash", {
|
||||
command: "pwd",
|
||||
}),
|
||||
);
|
||||
|
||||
await waitFor(() =>
|
||||
findControlResponseByRequestId(
|
||||
transport.writes,
|
||||
"pre-stream-approval",
|
||||
) !== undefined,
|
||||
);
|
||||
|
||||
expect(callbackInvocations).toBe(1);
|
||||
expect(
|
||||
findControlResponseByRequestId(
|
||||
transport.writes,
|
||||
"pre-stream-approval",
|
||||
),
|
||||
).toMatchObject({
|
||||
type: "control_response",
|
||||
response: {
|
||||
subtype: "success",
|
||||
request_id: "pre-stream-approval",
|
||||
response: {
|
||||
behavior: "allow",
|
||||
},
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
session.close();
|
||||
}
|
||||
});
|
||||
|
||||
test("bounds buffered stream messages and drops oldest deterministically", async () => {
|
||||
const session = new Session({
|
||||
permissionMode: "default",
|
||||
});
|
||||
const transport = new MockTransport();
|
||||
attachMockTransport(session, transport);
|
||||
|
||||
const assistantCount = BUFFER_LIMIT + 20;
|
||||
|
||||
try {
|
||||
transport.push(createInitMessage());
|
||||
await session.initialize();
|
||||
|
||||
for (let i = 1; i <= assistantCount; i++) {
|
||||
transport.push(createAssistantMessage(i));
|
||||
}
|
||||
transport.push(createResultMessage());
|
||||
transport.push(
|
||||
createCanUseToolRequest("post-result-marker", "EnterPlanMode", {}),
|
||||
);
|
||||
|
||||
await waitFor(() =>
|
||||
findControlResponseByRequestId(
|
||||
transport.writes,
|
||||
"post-result-marker",
|
||||
) !== undefined,
|
||||
);
|
||||
|
||||
const streamed: SDKMessage[] = [];
|
||||
for await (const msg of session.stream()) {
|
||||
streamed.push(msg);
|
||||
}
|
||||
|
||||
const assistants = streamed.filter(
|
||||
(msg): msg is Extract<SDKMessage, { type: "assistant" }> =>
|
||||
msg.type === "assistant",
|
||||
);
|
||||
|
||||
const expectedAssistantCount = BUFFER_LIMIT - 1;
|
||||
const expectedFirstAssistantIndex =
|
||||
assistantCount - expectedAssistantCount + 1;
|
||||
|
||||
expect(assistants.length).toBe(expectedAssistantCount);
|
||||
expect(assistants[0]?.content).toBe(
|
||||
`msg-${expectedFirstAssistantIndex}`,
|
||||
);
|
||||
expect(assistants[assistants.length - 1]?.content).toBe(
|
||||
`msg-${assistantCount}`,
|
||||
);
|
||||
expect(streamed[streamed.length - 1]?.type).toBe("result");
|
||||
} finally {
|
||||
session.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
191
src/session.ts
191
src/session.ts
@@ -33,6 +33,8 @@ function sessionLog(tag: string, ...args: unknown[]) {
|
||||
if (process.env.DEBUG_SDK) console.error(`[SDK-Session] [${tag}]`, ...args);
|
||||
}
|
||||
|
||||
const MAX_BUFFERED_STREAM_MESSAGES = 100;
|
||||
|
||||
export class Session implements AsyncDisposable {
|
||||
private transport: SubprocessTransport;
|
||||
private _agentId: string | null = null;
|
||||
@@ -40,7 +42,11 @@ export class Session implements AsyncDisposable {
|
||||
private _conversationId: string | null = null;
|
||||
private initialized = false;
|
||||
private externalTools: Map<string, AnyAgentTool> = new Map();
|
||||
|
||||
private streamQueue: SDKMessage[] = [];
|
||||
private streamResolvers: Array<(msg: SDKMessage | null) => void> = [];
|
||||
private pumpPromise: Promise<void> | null = null;
|
||||
private pumpClosed = false;
|
||||
private droppedStreamMessages = 0;
|
||||
|
||||
constructor(
|
||||
private options: InternalSessionOptions = {}
|
||||
@@ -79,6 +85,16 @@ export class Session implements AsyncDisposable {
|
||||
sessionLog("init", "waiting for init message from CLI...");
|
||||
for await (const msg of this.transport.messages()) {
|
||||
sessionLog("init", `received wire message: type=${msg.type}`);
|
||||
|
||||
if (msg.type === "control_request") {
|
||||
const handled = await this.handleControlRequest(msg as ControlRequest);
|
||||
if (!handled) {
|
||||
const wireMsgAny = msg as unknown as Record<string, unknown>;
|
||||
sessionLog("init", `DROPPED unsupported control_request: subtype=${(wireMsgAny.request as Record<string, unknown>)?.subtype || "N/A"}`);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (msg.type === "system" && "subtype" in msg && msg.subtype === "init") {
|
||||
const initMsg = msg as WireMessage & {
|
||||
agent_id: string;
|
||||
@@ -91,6 +107,7 @@ export class Session implements AsyncDisposable {
|
||||
this._sessionId = initMsg.session_id;
|
||||
this._conversationId = initMsg.conversation_id;
|
||||
this.initialized = true;
|
||||
this.startBackgroundPump();
|
||||
|
||||
// Register external tools with CLI
|
||||
if (this.externalTools.size > 0) {
|
||||
@@ -160,66 +177,144 @@ export class Session implements AsyncDisposable {
|
||||
async *stream(): AsyncGenerator<SDKMessage> {
|
||||
const streamStart = Date.now();
|
||||
let yieldCount = 0;
|
||||
let dropCount = 0;
|
||||
let gotResult = false;
|
||||
|
||||
this.startBackgroundPump();
|
||||
sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId})`);
|
||||
|
||||
for await (const wireMsg of this.transport.messages()) {
|
||||
// Handle CLI → SDK control requests (e.g., can_use_tool, execute_external_tool)
|
||||
if (wireMsg.type === "control_request") {
|
||||
const controlReq = wireMsg as ControlRequest;
|
||||
// Widen to string to allow SDK-extension subtypes not in the protocol union
|
||||
const subtype: string = controlReq.request.subtype;
|
||||
sessionLog("stream", `control_request: subtype=${subtype} tool=${(controlReq.request as CanUseToolControlRequest).tool_name || "N/A"}`);
|
||||
|
||||
if (subtype === "can_use_tool") {
|
||||
await this.handleCanUseTool(
|
||||
controlReq.request_id,
|
||||
controlReq.request as CanUseToolControlRequest
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (subtype === "execute_external_tool") {
|
||||
// SDK extension: not in protocol ControlRequestBody union, extract fields via Record
|
||||
const rawReq = controlReq.request as Record<string, unknown>;
|
||||
await this.handleExecuteExternalTool(
|
||||
controlReq.request_id,
|
||||
{
|
||||
subtype: "execute_external_tool",
|
||||
tool_call_id: rawReq.tool_call_id as string,
|
||||
tool_name: rawReq.tool_name as string,
|
||||
input: rawReq.input as Record<string, unknown>,
|
||||
}
|
||||
);
|
||||
continue;
|
||||
}
|
||||
while (true) {
|
||||
const sdkMsg = await this.nextBufferedMessage();
|
||||
if (!sdkMsg) {
|
||||
break;
|
||||
}
|
||||
|
||||
const sdkMsg = this.transformMessage(wireMsg);
|
||||
if (sdkMsg) {
|
||||
yieldCount++;
|
||||
sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`);
|
||||
yield sdkMsg;
|
||||
yieldCount++;
|
||||
sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`);
|
||||
yield sdkMsg;
|
||||
|
||||
// Stop on result message
|
||||
if (sdkMsg.type === "result") {
|
||||
gotResult = true;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
dropCount++;
|
||||
const wireMsgAny = wireMsg as unknown as Record<string, unknown>;
|
||||
sessionLog("stream", `DROPPED wire message #${dropCount}: type=${wireMsg.type} message_type=${wireMsgAny.message_type || "N/A"} subtype=${wireMsgAny.subtype || "N/A"}`);
|
||||
// Stop on result message
|
||||
if (sdkMsg.type === "result") {
|
||||
gotResult = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const elapsed = Date.now() - streamStart;
|
||||
sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${dropCount} gotResult=${gotResult}`);
|
||||
sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`);
|
||||
if (!gotResult) {
|
||||
sessionLog("stream", `WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly`);
|
||||
sessionLog("stream", "WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
private startBackgroundPump(): void {
|
||||
if (this.pumpPromise) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.pumpClosed = false;
|
||||
this.pumpPromise = this.runBackgroundPump()
|
||||
.catch((err) => {
|
||||
sessionLog("pump", `ERROR: ${err instanceof Error ? err.message : String(err)}`);
|
||||
})
|
||||
.finally(() => {
|
||||
this.pumpClosed = true;
|
||||
this.resolveAllStreamWaiters(null);
|
||||
});
|
||||
}
|
||||
|
||||
private async runBackgroundPump(): Promise<void> {
|
||||
sessionLog("pump", "background pump started");
|
||||
|
||||
for await (const wireMsg of this.transport.messages()) {
|
||||
if (wireMsg.type === "control_request") {
|
||||
const handled = await this.handleControlRequest(wireMsg as ControlRequest);
|
||||
if (!handled) {
|
||||
const wireMsgAny = wireMsg as unknown as Record<string, unknown>;
|
||||
sessionLog("pump", `DROPPED unsupported control_request: subtype=${(wireMsgAny.request as Record<string, unknown>)?.subtype || "N/A"}`);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const sdkMsg = this.transformMessage(wireMsg);
|
||||
if (sdkMsg) {
|
||||
this.enqueueStreamMessage(sdkMsg);
|
||||
} else {
|
||||
const wireMsgAny = wireMsg as unknown as Record<string, unknown>;
|
||||
sessionLog("pump", `DROPPED wire message: type=${wireMsg.type} message_type=${wireMsgAny.message_type || "N/A"} subtype=${wireMsgAny.subtype || "N/A"}`);
|
||||
}
|
||||
}
|
||||
|
||||
sessionLog("pump", "background pump ended");
|
||||
}
|
||||
|
||||
private async handleControlRequest(controlReq: ControlRequest): Promise<boolean> {
|
||||
// Widen to string to allow SDK-extension subtypes not in the protocol union
|
||||
const subtype: string = controlReq.request.subtype;
|
||||
sessionLog("pump", `control_request: subtype=${subtype} tool=${(controlReq.request as CanUseToolControlRequest).tool_name || "N/A"}`);
|
||||
|
||||
if (subtype === "can_use_tool") {
|
||||
await this.handleCanUseTool(
|
||||
controlReq.request_id,
|
||||
controlReq.request as CanUseToolControlRequest
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (subtype === "execute_external_tool") {
|
||||
// SDK extension: not in protocol ControlRequestBody union, extract fields via Record
|
||||
const rawReq = controlReq.request as Record<string, unknown>;
|
||||
await this.handleExecuteExternalTool(
|
||||
controlReq.request_id,
|
||||
{
|
||||
subtype: "execute_external_tool",
|
||||
tool_call_id: rawReq.tool_call_id as string,
|
||||
tool_name: rawReq.tool_name as string,
|
||||
input: rawReq.input as Record<string, unknown>,
|
||||
}
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private enqueueStreamMessage(msg: SDKMessage): void {
|
||||
if (this.streamResolvers.length > 0) {
|
||||
const resolve = this.streamResolvers.shift()!;
|
||||
resolve(msg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.streamQueue.length >= MAX_BUFFERED_STREAM_MESSAGES) {
|
||||
this.streamQueue.shift();
|
||||
this.droppedStreamMessages++;
|
||||
sessionLog("pump", `stream queue overflow: dropped oldest message (total_dropped=${this.droppedStreamMessages}, max=${MAX_BUFFERED_STREAM_MESSAGES})`);
|
||||
}
|
||||
|
||||
this.streamQueue.push(msg);
|
||||
}
|
||||
|
||||
private async nextBufferedMessage(): Promise<SDKMessage | null> {
|
||||
if (this.streamQueue.length > 0) {
|
||||
return this.streamQueue.shift()!;
|
||||
}
|
||||
|
||||
if (this.pumpClosed) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return new Promise((resolve) => {
|
||||
this.streamResolvers.push(resolve);
|
||||
});
|
||||
}
|
||||
|
||||
private resolveAllStreamWaiters(msg: SDKMessage | null): void {
|
||||
for (const resolve of this.streamResolvers) {
|
||||
resolve(msg);
|
||||
}
|
||||
this.streamResolvers = [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Register external tools with the CLI
|
||||
*/
|
||||
@@ -430,6 +525,8 @@ export class Session implements AsyncDisposable {
|
||||
close(): void {
|
||||
sessionLog("close", `closing session (agent=${this._agentId}, conversation=${this._conversationId})`);
|
||||
this.transport.close();
|
||||
this.pumpClosed = true;
|
||||
this.resolveAllStreamWaiters(null);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user