diff --git a/src/session.ts b/src/session.ts index 4fea98a..4772af9 100644 --- a/src/session.ts +++ b/src/session.ts @@ -41,6 +41,12 @@ function sessionLog(tag: string, ...args: unknown[]) { const MAX_BUFFERED_STREAM_MESSAGES = 100; +type BufferedStreamMessage = { + message: SDKMessage; + generation: number; + runId?: string; +}; + export class Session implements AsyncDisposable { private transport: SubprocessTransport; private _agentId: string | null = null; @@ -48,11 +54,19 @@ export class Session implements AsyncDisposable { private _conversationId: string | null = null; private initialized = false; private externalTools: Map = new Map(); - private streamQueue: SDKMessage[] = []; - private streamResolvers: Array<(msg: SDKMessage | null) => void> = []; + private streamQueue: BufferedStreamMessage[] = []; + private streamResolvers: Array<(msg: BufferedStreamMessage | null) => void> = []; private pumpPromise: Promise | null = null; private pumpClosed = false; private droppedStreamMessages = 0; + // Monotonic counter incremented after each send(). Messages enqueued by the + // pump are tagged with the current generation; stream() filters out messages + // from earlier generations to prevent N-1 desync (stale events from a + // previous run leaking into the current run's stream). + private sendGeneration = 0; + // Run IDs that completed in the previous streamed turn. Used to drop + // late-arriving stale events from the old run if they arrive after send(). + private lastCompletedRunIds = new Set(); // Waiters for SDK-initiated control requests (e.g., listMessages). // Keyed by request_id; pump resolves the matching waiter when it sees // a control_response with that request_id instead of queuing it as a stream msg. @@ -206,7 +220,12 @@ export class Session implements AsyncDisposable { type: "user", message: { role: "user", content: message }, }); - sessionLog("send", "message written to transport"); + + // Advance generation AFTER the write so any messages the pump enqueues + // during the await (from the previous run's lingering events) are tagged + // with the old generation and will be filtered by stream(). + this.sendGeneration++; + sessionLog("send", `message written to transport (generation=${this.sendGeneration})`); } /** @@ -214,18 +233,44 @@ export class Session implements AsyncDisposable { */ async *stream(): AsyncGenerator { const streamStart = Date.now(); + const minGeneration = this.sendGeneration; let yieldCount = 0; + let staleCount = 0; + let staleRunIdCount = 0; let gotResult = false; + const currentStreamRunIds = new Set(); + const staleRunIds = new Set(this.lastCompletedRunIds); this.startBackgroundPump(); - sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId})`); + sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId}, generation=${minGeneration})`); while (true) { - const sdkMsg = await this.nextBufferedMessage(); - if (!sdkMsg) { + const bufferedMsg = await this.nextBufferedMessage(); + if (!bufferedMsg) { break; } + // Filter stale messages from previous runs. Messages enqueued before + // the current send() carry an older generation tag. + if (bufferedMsg.generation < minGeneration) { + staleCount++; + sessionLog("stream", `discarding stale message: type=${bufferedMsg.message.type} generation=${bufferedMsg.generation} (current=${minGeneration})`); + continue; + } + + // Filter late old-run messages that arrive after send() has already + // advanced generation and stream queue was cleared. + if (bufferedMsg.runId && staleRunIds.has(bufferedMsg.runId)) { + staleRunIdCount++; + sessionLog("stream", `discarding stale message: type=${bufferedMsg.message.type} runId=${bufferedMsg.runId}`); + continue; + } + + if (bufferedMsg.runId) { + currentStreamRunIds.add(bufferedMsg.runId); + } + + const sdkMsg = bufferedMsg.message; yieldCount++; sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`); yield sdkMsg; @@ -233,12 +278,13 @@ export class Session implements AsyncDisposable { // Stop on result message if (sdkMsg.type === "result") { gotResult = true; + this.updateCompletedRunIds((sdkMsg as SDKResultMessage).runIds, currentStreamRunIds); break; } } const elapsed = Date.now() - streamStart; - sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`); + sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} staleFiltered=${staleCount} staleRunIdFiltered=${staleRunIdCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`); if (!gotResult) { sessionLog("stream", "WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly"); } @@ -374,9 +420,15 @@ export class Session implements AsyncDisposable { } private enqueueStreamMessage(msg: SDKMessage): void { + const bufferedMsg: BufferedStreamMessage = { + message: msg, + generation: this.sendGeneration, + runId: this.getMessageRunId(msg), + }; + if (this.streamResolvers.length > 0) { const resolve = this.streamResolvers.shift()!; - resolve(msg); + resolve(bufferedMsg); return; } @@ -386,10 +438,10 @@ export class Session implements AsyncDisposable { sessionLog("pump", `stream queue overflow: dropped oldest message (total_dropped=${this.droppedStreamMessages}, max=${MAX_BUFFERED_STREAM_MESSAGES})`); } - this.streamQueue.push(msg); + this.streamQueue.push(bufferedMsg); } - private async nextBufferedMessage(): Promise { + private async nextBufferedMessage(): Promise { if (this.streamQueue.length > 0) { return this.streamQueue.shift()!; } @@ -403,7 +455,7 @@ export class Session implements AsyncDisposable { }); } - private resolveAllStreamWaiters(msg: SDKMessage | null): void { + private resolveAllStreamWaiters(msg: BufferedStreamMessage | null): void { for (const resolve of this.streamResolvers) { resolve(msg); } @@ -415,6 +467,43 @@ export class Session implements AsyncDisposable { this.controlResponseWaiters.clear(); } + private getMessageRunId(msg: SDKMessage): string | undefined { + switch (msg.type) { + case "assistant": + case "tool_call": + case "tool_result": + case "reasoning": + case "error": + case "retry": + return msg.runId; + default: + return undefined; + } + } + + private updateCompletedRunIds( + resultRunIds: string[] | undefined, + streamedRunIds: Set, + ): void { + const nextRunIds = new Set(); + + if (Array.isArray(resultRunIds)) { + for (const runId of resultRunIds) { + if (runId) { + nextRunIds.add(runId); + } + } + } + + for (const runId of streamedRunIds) { + if (runId) { + nextRunIds.add(runId); + } + } + + this.lastCompletedRunIds = nextRunIds; + } + /** * Register external tools with the CLI */ @@ -843,6 +932,7 @@ export class Session implements AsyncDisposable { const msg = wireMsg as WireMessage & { message_type: string; uuid: string; + run_id?: string; // assistant_message fields content?: string; // tool_call_message fields @@ -856,12 +946,15 @@ export class Session implements AsyncDisposable { reasoning?: string; }; + const runId = msg.run_id || undefined; + // Assistant message if (msg.message_type === "assistant_message" && msg.content) { return { type: "assistant", content: msg.content, uuid: msg.uuid, + runId, }; } @@ -899,6 +992,7 @@ export class Session implements AsyncDisposable { toolInput, rawArguments: toolArgs || undefined, uuid: msg.uuid, + runId, }; } } @@ -911,6 +1005,7 @@ export class Session implements AsyncDisposable { content: msg.tool_return || "", isError: msg.status === "error", uuid: msg.uuid, + runId, }; } @@ -920,6 +1015,7 @@ export class Session implements AsyncDisposable { type: "reasoning", content: msg.reasoning, uuid: msg.uuid, + runId, }; } } @@ -947,7 +1043,11 @@ export class Session implements AsyncDisposable { total_cost_usd?: number; conversation_id: string; stop_reason?: string; + run_ids?: unknown[]; }; + const runIds = Array.isArray(msg.run_ids) + ? msg.run_ids.filter((id): id is string => typeof id === "string") + : undefined; return { type: "result", success: msg.subtype === "success", @@ -957,6 +1057,7 @@ export class Session implements AsyncDisposable { durationMs: msg.duration_ms, totalCostUsd: msg.total_cost_usd, conversationId: msg.conversation_id, + runIds, }; } diff --git a/src/tests/session.test.ts b/src/tests/session.test.ts index 3c99abb..2d3a0c4 100644 --- a/src/tests/session.test.ts +++ b/src/tests/session.test.ts @@ -87,12 +87,20 @@ function createInitMessage( } as WireMessage; } -function createAssistantMessage(index: number): WireMessage { +function createAssistantMessage( + index: number, + overrides: Partial<{ + uuid: string; + content: string; + run_id: string; + }> = {}, +): WireMessage { return { type: "message", message_type: "assistant_message", uuid: `assistant-${index}`, content: `msg-${index}`, + ...overrides, } as WireMessage; } @@ -116,7 +124,16 @@ function createApprovalRequestMessage( }; } -function createResultMessage(): WireMessage { +function createResultMessage( + overrides: Partial<{ + subtype: string; + result: string | null; + duration_ms: number; + conversation_id: string; + stop_reason: string; + run_ids: unknown[]; + }> = {}, +): WireMessage { return { type: "result", subtype: "success", @@ -124,6 +141,7 @@ function createResultMessage(): WireMessage { duration_ms: 1, conversation_id: "conversation-1", stop_reason: "end_turn", + ...overrides, } as WireMessage; } @@ -434,6 +452,45 @@ describe("Session", () => { }); }); + describe("transformMessage result mapping", () => { + test("maps result wire message run_ids to SDK runIds", () => { + const session = new Session(); + const wireMsg = createResultMessage({ + run_ids: ["run-1", "run-2"], + }); + + // @ts-expect-error - accessing private method for regression coverage + const transformed = session.transformMessage(wireMsg) as SDKMessage | null; + + expect(transformed).toEqual({ + type: "result", + success: true, + result: "done", + error: undefined, + stopReason: "end_turn", + durationMs: 1, + totalCostUsd: undefined, + conversationId: "conversation-1", + runIds: ["run-1", "run-2"], + }); + }); + + test("filters non-string run_ids and preserves valid values", () => { + const session = new Session(); + const wireMsg = createResultMessage({ + run_ids: ["run-1", 42, null, "run-2"], + }); + + // @ts-expect-error - accessing private method for regression coverage + const transformed = session.transformMessage(wireMsg) as SDKMessage | null; + + expect(transformed).toMatchObject({ + type: "result", + runIds: ["run-1", "run-2"], + }); + }); + }); + describe("transformMessage error/retry mapping", () => { test("maps error wire message to SDK error message", () => { const session = new Session(); @@ -608,4 +665,196 @@ describe("Session", () => { } }); }); + + describe("generation-based stale message filtering", () => { + test("filters stale messages that arrive late from the previous run_id", async () => { + const session = new Session(); + const transport = new MockTransport(); + attachMockTransport(session, transport); + + try { + transport.push(createInitMessage()); + await session.initialize(); + + // First send + stream establishes run-1 as completed. + transport.push(createAssistantMessage(1, { run_id: "run-1" })); + transport.push( + createResultMessage({ + result: "first", + run_ids: ["run-1"], + }), + ); + await session.send("first message"); + + const firstMessages: SDKMessage[] = []; + for await (const msg of session.stream()) { + firstMessages.push(msg); + } + expect(firstMessages).toHaveLength(2); + + // Second send starts a new run, but an old run-1 message arrives late. + await session.send("second message"); + transport.push( + createAssistantMessage(999, { + uuid: "assistant-stale-old-run", + content: "stale-old-run", + run_id: "run-1", + }), + ); + transport.push(createAssistantMessage(2, { run_id: "run-2" })); + transport.push( + createResultMessage({ + result: "second", + run_ids: ["run-2"], + }), + ); + + const secondMessages: SDKMessage[] = []; + for await (const msg of session.stream()) { + secondMessages.push(msg); + } + + // The stale run-1 message should be filtered; only fresh run-2 messages remain. + expect(secondMessages).toHaveLength(2); + expect((secondMessages[0] as { content: string }).content).toBe("msg-2"); + expect(secondMessages[1]?.type).toBe("result"); + } finally { + session.close(); + } + }); + + test("does not leak internal generation metadata on emitted SDK messages", async () => { + const session = new Session(); + const transport = new MockTransport(); + attachMockTransport(session, transport); + + try { + transport.push(createInitMessage()); + await session.initialize(); + + transport.push(createAssistantMessage(1, { run_id: "run-1" })); + transport.push(createResultMessage({ run_ids: ["run-1"] })); + await session.send("hello"); + + const streamed: SDKMessage[] = []; + for await (const msg of session.stream()) { + streamed.push(msg); + } + + const assistant = streamed.find( + (msg): msg is Extract => + msg.type === "assistant", + ); + expect(assistant).toBeDefined(); + if (assistant) { + expect( + "_generation" in (assistant as unknown as Record), + ).toBe( + false, + ); + expect(Object.keys(assistant)).not.toContain("_generation"); + } + } finally { + session.close(); + } + }); + }); + + describe("transformMessage run_id pass-through", () => { + test("includes runId on assistant messages", () => { + const session = new Session(); + const wireMsg = { + type: "message", + message_type: "assistant_message", + uuid: "a-1", + content: "hello", + run_id: "run-abc", + } as WireMessage; + + // @ts-expect-error - accessing private method + const transformed = session.transformMessage(wireMsg); + expect(transformed).toMatchObject({ + type: "assistant", + content: "hello", + runId: "run-abc", + }); + }); + + test("includes runId on tool_call messages", () => { + const session = new Session(); + const wireMsg = { + type: "message", + message_type: "tool_call_message", + uuid: "tc-1", + run_id: "run-abc", + tool_calls: [{ + tool_call_id: "call-1", + name: "Edit", + arguments: "{}", + }], + } as WireMessage; + + // @ts-expect-error - accessing private method + const transformed = session.transformMessage(wireMsg); + expect(transformed).toMatchObject({ + type: "tool_call", + toolName: "Edit", + runId: "run-abc", + }); + }); + + test("includes runId on reasoning messages", () => { + const session = new Session(); + const wireMsg = { + type: "message", + message_type: "reasoning_message", + uuid: "r-1", + reasoning: "thinking...", + run_id: "run-abc", + } as WireMessage; + + // @ts-expect-error - accessing private method + const transformed = session.transformMessage(wireMsg); + expect(transformed).toMatchObject({ + type: "reasoning", + content: "thinking...", + runId: "run-abc", + }); + }); + + test("includes runId on tool_result messages", () => { + const session = new Session(); + const wireMsg = { + type: "message", + message_type: "tool_return_message", + uuid: "tr-1", + tool_call_id: "call-1", + tool_return: "success", + status: "success", + run_id: "run-abc", + } as WireMessage; + + // @ts-expect-error - accessing private method + const transformed = session.transformMessage(wireMsg); + expect(transformed).toMatchObject({ + type: "tool_result", + runId: "run-abc", + }); + }); + + test("runId is undefined when wire message lacks run_id", () => { + const session = new Session(); + const wireMsg = { + type: "message", + message_type: "assistant_message", + uuid: "a-2", + content: "no run id", + } as WireMessage; + + // @ts-expect-error - accessing private method + const transformed = session.transformMessage(wireMsg); + expect(transformed).toMatchObject({ type: "assistant" }); + expect((transformed as { runId?: string }).runId).toBeUndefined(); + }); + }); }); diff --git a/src/tests/tool-call-args-accumulation.test.ts b/src/tests/tool-call-args-accumulation.test.ts index 0c6ca12..9dd0d36 100644 --- a/src/tests/tool-call-args-accumulation.test.ts +++ b/src/tests/tool-call-args-accumulation.test.ts @@ -112,8 +112,22 @@ function reasoningChunk(uuid: string, text = "done"): WireMessage { } function queuedMessages(session: Session) { - return ((session as unknown as { streamQueue: unknown[] }).streamQueue ?? - []) as Array>; + const queue = + (session as unknown as { streamQueue?: unknown[] }).streamQueue ?? []; + + return queue.map((entry) => { + if ( + entry && + typeof entry === "object" && + "message" in entry && + (entry as { message?: unknown }).message && + typeof (entry as { message?: unknown }).message === "object" + ) { + return (entry as { message: Record }).message; + } + + return entry as Record; + }); } describe("tool call streaming passthrough", () => { diff --git a/src/types.ts b/src/types.ts index 06b6e71..28a42c0 100644 --- a/src/types.ts +++ b/src/types.ts @@ -453,6 +453,8 @@ export interface SDKAssistantMessage { type: "assistant"; content: string; uuid: string; + /** Run ID from the Letta API for this event (used for stale-run detection). */ + runId?: string; } export interface SDKToolCallMessage { @@ -463,6 +465,8 @@ export interface SDKToolCallMessage { /** Raw unparsed arguments string from the wire for consumer-side accumulation. */ rawArguments?: string; uuid: string; + /** Run ID from the Letta API for this event (used for stale-run detection). */ + runId?: string; } export interface SDKToolResultMessage { @@ -471,12 +475,16 @@ export interface SDKToolResultMessage { content: string; isError: boolean; uuid: string; + /** Run ID from the Letta API for this event (used for stale-run detection). */ + runId?: string; } export interface SDKReasoningMessage { type: "reasoning"; content: string; uuid: string; + /** Run ID from the Letta API for this event (used for stale-run detection). */ + runId?: string; } export interface SDKResultMessage { @@ -488,6 +496,8 @@ export interface SDKResultMessage { durationMs: number; totalCostUsd?: number; conversationId: string | null; + /** Run IDs associated with this turn (if provided by the CLI). */ + runIds?: string[]; } export interface SDKStreamEventDeltaPayload {