fix: add generation counter and run_id pass-through to prevent N-1 desync (#69)

Co-authored-by: letta-code <248085862+letta-code@users.noreply.github.com>
Co-authored-by: Cameron <cpfiffer@users.noreply.github.com>
This commit is contained in:
Cameron
2026-03-03 11:09:28 -08:00
committed by GitHub
parent 6022cd8af0
commit 30c539e1e3
4 changed files with 389 additions and 15 deletions

View File

@@ -41,6 +41,12 @@ function sessionLog(tag: string, ...args: unknown[]) {
const MAX_BUFFERED_STREAM_MESSAGES = 100;
/**
 * A stream message paired with the bookkeeping that stream() uses to discard
 * events left over from an earlier run.
 */
type BufferedStreamMessage = {
// The SDK message itself; stream() yields only this part to callers.
message: SDKMessage;
// Value of the session's sendGeneration counter at the moment of enqueueing.
generation: number;
// Run ID taken from the message when its type carries one; otherwise undefined.
runId?: string;
};
export class Session implements AsyncDisposable {
private transport: SubprocessTransport;
private _agentId: string | null = null;
@@ -48,11 +54,19 @@ export class Session implements AsyncDisposable {
private _conversationId: string | null = null;
private initialized = false;
private externalTools: Map<string, AnyAgentTool> = new Map();
private streamQueue: SDKMessage[] = [];
private streamResolvers: Array<(msg: SDKMessage | null) => void> = [];
private streamQueue: BufferedStreamMessage[] = [];
private streamResolvers: Array<(msg: BufferedStreamMessage | null) => void> = [];
private pumpPromise: Promise<void> | null = null;
private pumpClosed = false;
private droppedStreamMessages = 0;
// Monotonic counter incremented after each send(). Messages enqueued by the
// pump are tagged with the current generation; stream() filters out messages
// from earlier generations to prevent N-1 desync (stale events from a
// previous run leaking into the current run's stream).
private sendGeneration = 0;
// Run IDs that completed in the previous streamed turn. Used to drop
// late-arriving stale events from the old run if they arrive after send().
private lastCompletedRunIds = new Set<string>();
// Waiters for SDK-initiated control requests (e.g., listMessages).
// Keyed by request_id; pump resolves the matching waiter when it sees
// a control_response with that request_id instead of queuing it as a stream msg.
@@ -206,7 +220,12 @@ export class Session implements AsyncDisposable {
type: "user",
message: { role: "user", content: message },
});
sessionLog("send", "message written to transport");
// Advance generation AFTER the write so any messages the pump enqueues
// during the await (from the previous run's lingering events) are tagged
// with the old generation and will be filtered by stream().
this.sendGeneration++;
sessionLog("send", `message written to transport (generation=${this.sendGeneration})`);
}
/**
@@ -214,18 +233,44 @@ export class Session implements AsyncDisposable {
*/
async *stream(): AsyncGenerator<SDKMessage> {
const streamStart = Date.now();
const minGeneration = this.sendGeneration;
let yieldCount = 0;
let staleCount = 0;
let staleRunIdCount = 0;
let gotResult = false;
const currentStreamRunIds = new Set<string>();
const staleRunIds = new Set(this.lastCompletedRunIds);
this.startBackgroundPump();
sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId})`);
sessionLog("stream", `starting stream (agent=${this._agentId}, conversation=${this._conversationId}, generation=${minGeneration})`);
while (true) {
const sdkMsg = await this.nextBufferedMessage();
if (!sdkMsg) {
const bufferedMsg = await this.nextBufferedMessage();
if (!bufferedMsg) {
break;
}
// Filter stale messages from previous runs. Messages enqueued before
// the current send() carry an older generation tag.
if (bufferedMsg.generation < minGeneration) {
staleCount++;
sessionLog("stream", `discarding stale message: type=${bufferedMsg.message.type} generation=${bufferedMsg.generation} (current=${minGeneration})`);
continue;
}
// Filter late old-run messages that arrive after send() has already
// advanced generation and stream queue was cleared.
if (bufferedMsg.runId && staleRunIds.has(bufferedMsg.runId)) {
staleRunIdCount++;
sessionLog("stream", `discarding stale message: type=${bufferedMsg.message.type} runId=${bufferedMsg.runId}`);
continue;
}
if (bufferedMsg.runId) {
currentStreamRunIds.add(bufferedMsg.runId);
}
const sdkMsg = bufferedMsg.message;
yieldCount++;
sessionLog("stream", `yield #${yieldCount}: type=${sdkMsg.type}${sdkMsg.type === "result" ? ` success=${(sdkMsg as SDKResultMessage).success} error=${(sdkMsg as SDKResultMessage).error || "none"}` : ""}`);
yield sdkMsg;
@@ -233,12 +278,13 @@ export class Session implements AsyncDisposable {
// Stop on result message
if (sdkMsg.type === "result") {
gotResult = true;
this.updateCompletedRunIds((sdkMsg as SDKResultMessage).runIds, currentStreamRunIds);
break;
}
}
const elapsed = Date.now() - streamStart;
sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`);
sessionLog("stream", `stream ended: duration=${elapsed}ms yielded=${yieldCount} staleFiltered=${staleCount} staleRunIdFiltered=${staleRunIdCount} dropped=${this.droppedStreamMessages} gotResult=${gotResult}`);
if (!gotResult) {
sessionLog("stream", "WARNING: stream ended WITHOUT a result message -- transport may have closed unexpectedly");
}
@@ -374,9 +420,15 @@ export class Session implements AsyncDisposable {
}
private enqueueStreamMessage(msg: SDKMessage): void {
const bufferedMsg: BufferedStreamMessage = {
message: msg,
generation: this.sendGeneration,
runId: this.getMessageRunId(msg),
};
if (this.streamResolvers.length > 0) {
const resolve = this.streamResolvers.shift()!;
resolve(msg);
resolve(bufferedMsg);
return;
}
@@ -386,10 +438,10 @@ export class Session implements AsyncDisposable {
sessionLog("pump", `stream queue overflow: dropped oldest message (total_dropped=${this.droppedStreamMessages}, max=${MAX_BUFFERED_STREAM_MESSAGES})`);
}
this.streamQueue.push(msg);
this.streamQueue.push(bufferedMsg);
}
private async nextBufferedMessage(): Promise<SDKMessage | null> {
private async nextBufferedMessage(): Promise<BufferedStreamMessage | null> {
if (this.streamQueue.length > 0) {
return this.streamQueue.shift()!;
}
@@ -403,7 +455,7 @@ export class Session implements AsyncDisposable {
});
}
private resolveAllStreamWaiters(msg: SDKMessage | null): void {
private resolveAllStreamWaiters(msg: BufferedStreamMessage | null): void {
for (const resolve of this.streamResolvers) {
resolve(msg);
}
@@ -415,6 +467,43 @@ export class Session implements AsyncDisposable {
this.controlResponseWaiters.clear();
}
/**
 * Extracts the run ID from an SDK message, when its type carries one.
 *
 * @param msg - SDK message to inspect.
 * @returns `msg.runId` for assistant, tool_call, tool_result, reasoning, error
 *   and retry messages; `undefined` for every other message type.
 */
private getMessageRunId(msg: SDKMessage): string | undefined {
  // Checking the discriminant directly keeps `msg` narrowed to the variants
  // that declare `runId`.
  if (
    msg.type === "assistant" ||
    msg.type === "tool_call" ||
    msg.type === "tool_result" ||
    msg.type === "reasoning" ||
    msg.type === "error" ||
    msg.type === "retry"
  ) {
    return msg.runId;
  }
  return undefined;
}
/**
 * Replaces `lastCompletedRunIds` with the run IDs of the turn that just ended.
 *
 * @param resultRunIds - Run IDs reported on the result message; ignored when
 *   it is not an array.
 * @param streamedRunIds - Run IDs observed on messages during the stream.
 */
private updateCompletedRunIds(
  resultRunIds: string[] | undefined,
  streamedRunIds: Set<string>,
): void {
  const reported = Array.isArray(resultRunIds) ? resultRunIds : [];
  // Result-reported IDs come first, then streamed IDs; empty IDs are dropped
  // and the Set removes duplicates.
  const completed = [...reported, ...streamedRunIds].filter((id) => Boolean(id));
  this.lastCompletedRunIds = new Set<string>(completed);
}
/**
* Register external tools with the CLI
*/
@@ -843,6 +932,7 @@ export class Session implements AsyncDisposable {
const msg = wireMsg as WireMessage & {
message_type: string;
uuid: string;
run_id?: string;
// assistant_message fields
content?: string;
// tool_call_message fields
@@ -856,12 +946,15 @@ export class Session implements AsyncDisposable {
reasoning?: string;
};
const runId = msg.run_id || undefined;
// Assistant message
if (msg.message_type === "assistant_message" && msg.content) {
return {
type: "assistant",
content: msg.content,
uuid: msg.uuid,
runId,
};
}
@@ -899,6 +992,7 @@ export class Session implements AsyncDisposable {
toolInput,
rawArguments: toolArgs || undefined,
uuid: msg.uuid,
runId,
};
}
}
@@ -911,6 +1005,7 @@ export class Session implements AsyncDisposable {
content: msg.tool_return || "",
isError: msg.status === "error",
uuid: msg.uuid,
runId,
};
}
@@ -920,6 +1015,7 @@ export class Session implements AsyncDisposable {
type: "reasoning",
content: msg.reasoning,
uuid: msg.uuid,
runId,
};
}
}
@@ -947,7 +1043,11 @@ export class Session implements AsyncDisposable {
total_cost_usd?: number;
conversation_id: string;
stop_reason?: string;
run_ids?: unknown[];
};
const runIds = Array.isArray(msg.run_ids)
? msg.run_ids.filter((id): id is string => typeof id === "string")
: undefined;
return {
type: "result",
success: msg.subtype === "success",
@@ -957,6 +1057,7 @@ export class Session implements AsyncDisposable {
durationMs: msg.duration_ms,
totalCostUsd: msg.total_cost_usd,
conversationId: msg.conversation_id,
runIds,
};
}

View File

@@ -87,12 +87,20 @@ function createInitMessage(
} as WireMessage;
}
function createAssistantMessage(index: number): WireMessage {
function createAssistantMessage(
index: number,
overrides: Partial<{
uuid: string;
content: string;
run_id: string;
}> = {},
): WireMessage {
return {
type: "message",
message_type: "assistant_message",
uuid: `assistant-${index}`,
content: `msg-${index}`,
...overrides,
} as WireMessage;
}
@@ -116,7 +124,16 @@ function createApprovalRequestMessage(
};
}
function createResultMessage(): WireMessage {
function createResultMessage(
overrides: Partial<{
subtype: string;
result: string | null;
duration_ms: number;
conversation_id: string;
stop_reason: string;
run_ids: unknown[];
}> = {},
): WireMessage {
return {
type: "result",
subtype: "success",
@@ -124,6 +141,7 @@ function createResultMessage(): WireMessage {
duration_ms: 1,
conversation_id: "conversation-1",
stop_reason: "end_turn",
...overrides,
} as WireMessage;
}
@@ -434,6 +452,45 @@ describe("Session", () => {
});
});
// Covers how transformMessage maps a wire "result" message's run_ids field
// onto the SDK result message's runIds field.
describe("transformMessage result mapping", () => {
// String run IDs are carried over unchanged, alongside the other result fields.
test("maps result wire message run_ids to SDK runIds", () => {
const session = new Session();
const wireMsg = createResultMessage({
run_ids: ["run-1", "run-2"],
});
// @ts-expect-error - accessing private method for regression coverage
const transformed = session.transformMessage(wireMsg) as SDKMessage | null;
// Pin the full mapped shape so an accidental extra or renamed field is noticed.
expect(transformed).toEqual({
type: "result",
success: true,
result: "done",
error: undefined,
stopReason: "end_turn",
durationMs: 1,
totalCostUsd: undefined,
conversationId: "conversation-1",
runIds: ["run-1", "run-2"],
});
});
// Non-string entries (here a number and null) are dropped; the surviving
// strings keep their original order.
test("filters non-string run_ids and preserves valid values", () => {
const session = new Session();
const wireMsg = createResultMessage({
run_ids: ["run-1", 42, null, "run-2"],
});
// @ts-expect-error - accessing private method for regression coverage
const transformed = session.transformMessage(wireMsg) as SDKMessage | null;
expect(transformed).toMatchObject({
type: "result",
runIds: ["run-1", "run-2"],
});
});
});
describe("transformMessage error/retry mapping", () => {
test("maps error wire message to SDK error message", () => {
const session = new Session();
@@ -608,4 +665,196 @@ describe("Session", () => {
}
});
});
describe("generation-based stale message filtering", () => {
// Two-turn scenario: a message tagged with the first turn's run ID is pushed
// after the second send(), and must not appear in the second stream.
test("filters stale messages that arrive late from the previous run_id", async () => {
const session = new Session();
const transport = new MockTransport();
attachMockTransport(session, transport);
try {
transport.push(createInitMessage());
await session.initialize();
// First send + stream establishes run-1 as completed.
transport.push(createAssistantMessage(1, { run_id: "run-1" }));
transport.push(
createResultMessage({
result: "first",
run_ids: ["run-1"],
}),
);
await session.send("first message");
const firstMessages: SDKMessage[] = [];
for await (const msg of session.stream()) {
firstMessages.push(msg);
}
// One assistant message plus the result message.
expect(firstMessages).toHaveLength(2);
// Second send starts a new run, but an old run-1 message arrives late.
await session.send("second message");
// Pushed after send(), so presumably the generation tag alone cannot mark it
// stale and the run-ID check has to (depends on MockTransport timing).
transport.push(
createAssistantMessage(999, {
uuid: "assistant-stale-old-run",
content: "stale-old-run",
run_id: "run-1",
}),
);
transport.push(createAssistantMessage(2, { run_id: "run-2" }));
transport.push(
createResultMessage({
result: "second",
run_ids: ["run-2"],
}),
);
const secondMessages: SDKMessage[] = [];
for await (const msg of session.stream()) {
secondMessages.push(msg);
}
// The stale run-1 message should be filtered; only fresh run-2 messages remain.
expect(secondMessages).toHaveLength(2);
// "msg-2" is the default content createAssistantMessage generates for index 2.
expect((secondMessages[0] as { content: string }).content).toBe("msg-2");
expect(secondMessages[1]?.type).toBe("result");
} finally {
// Always close the session, even when an assertion above throws.
session.close();
}
});
// The session wraps each buffered message with internal bookkeeping (a
// `generation` tag). This test verifies that bookkeeping never shows up as a
// property on the messages handed to stream() consumers.
test("does not leak internal generation metadata on emitted SDK messages", async () => {
  const session = new Session();
  const transport = new MockTransport();
  attachMockTransport(session, transport);
  try {
    transport.push(createInitMessage());
    await session.initialize();
    transport.push(createAssistantMessage(1, { run_id: "run-1" }));
    transport.push(createResultMessage({ run_ids: ["run-1"] }));
    await session.send("hello");

    const streamed: SDKMessage[] = [];
    for await (const msg of session.stream()) {
      streamed.push(msg);
    }

    // Guard against a vacuous pass: the assistant message must actually have
    // been yielded before its keys are inspected.
    const assistant = streamed.find(
      (msg): msg is Extract<SDKMessage, { type: "assistant" }> =>
        msg.type === "assistant",
    );
    expect(assistant).toBeDefined();

    // Check every yielded message. `generation` is the field name the buffer
    // wrapper uses; `_generation` is kept from the original assertion so a
    // regression to an underscore-prefixed tag is still caught.
    for (const msg of streamed) {
      const keys = Object.keys(msg);
      expect(keys).not.toContain("generation");
      expect(keys).not.toContain("_generation");
    }
  } finally {
    // Always close the session, even when an assertion above throws.
    session.close();
  }
});
});
// Verifies that transformMessage copies the wire-level run_id onto the SDK
// message's runId field for each message type that carries one.
describe("transformMessage run_id pass-through", () => {
// assistant_message -> SDK "assistant" message.
test("includes runId on assistant messages", () => {
const session = new Session();
const wireMsg = {
type: "message",
message_type: "assistant_message",
uuid: "a-1",
content: "hello",
run_id: "run-abc",
} as WireMessage;
// @ts-expect-error - accessing private method
const transformed = session.transformMessage(wireMsg);
expect(transformed).toMatchObject({
type: "assistant",
content: "hello",
runId: "run-abc",
});
});
// tool_call_message -> SDK "tool_call" message.
test("includes runId on tool_call messages", () => {
const session = new Session();
const wireMsg = {
type: "message",
message_type: "tool_call_message",
uuid: "tc-1",
run_id: "run-abc",
tool_calls: [{
tool_call_id: "call-1",
name: "Edit",
arguments: "{}",
}],
} as WireMessage;
// @ts-expect-error - accessing private method
const transformed = session.transformMessage(wireMsg);
expect(transformed).toMatchObject({
type: "tool_call",
toolName: "Edit",
runId: "run-abc",
});
});
// reasoning_message -> SDK "reasoning" message; the wire `reasoning` text
// becomes `content`.
test("includes runId on reasoning messages", () => {
const session = new Session();
const wireMsg = {
type: "message",
message_type: "reasoning_message",
uuid: "r-1",
reasoning: "thinking...",
run_id: "run-abc",
} as WireMessage;
// @ts-expect-error - accessing private method
const transformed = session.transformMessage(wireMsg);
expect(transformed).toMatchObject({
type: "reasoning",
content: "thinking...",
runId: "run-abc",
});
});
// tool_return_message -> SDK "tool_result" message.
test("includes runId on tool_result messages", () => {
const session = new Session();
const wireMsg = {
type: "message",
message_type: "tool_return_message",
uuid: "tr-1",
tool_call_id: "call-1",
tool_return: "success",
status: "success",
run_id: "run-abc",
} as WireMessage;
// @ts-expect-error - accessing private method
const transformed = session.transformMessage(wireMsg);
expect(transformed).toMatchObject({
type: "tool_result",
runId: "run-abc",
});
});
// With no run_id on the wire message, the SDK message's runId stays undefined.
test("runId is undefined when wire message lacks run_id", () => {
const session = new Session();
const wireMsg = {
type: "message",
message_type: "assistant_message",
uuid: "a-2",
content: "no run id",
} as WireMessage;
// @ts-expect-error - accessing private method
const transformed = session.transformMessage(wireMsg);
expect(transformed).toMatchObject({ type: "assistant" });
expect((transformed as { runId?: string }).runId).toBeUndefined();
});
});
});

View File

@@ -112,8 +112,22 @@ function reasoningChunk(uuid: string, text = "done"): WireMessage {
}
function queuedMessages(session: Session) {
return ((session as unknown as { streamQueue: unknown[] }).streamQueue ??
[]) as Array<Record<string, unknown>>;
const queue =
(session as unknown as { streamQueue?: unknown[] }).streamQueue ?? [];
return queue.map((entry) => {
if (
entry &&
typeof entry === "object" &&
"message" in entry &&
(entry as { message?: unknown }).message &&
typeof (entry as { message?: unknown }).message === "object"
) {
return (entry as { message: Record<string, unknown> }).message;
}
return entry as Record<string, unknown>;
});
}
describe("tool call streaming passthrough", () => {

View File

@@ -453,6 +453,8 @@ export interface SDKAssistantMessage {
type: "assistant";
content: string;
uuid: string;
/** Run ID from the Letta API for this event (used for stale-run detection). */
runId?: string;
}
export interface SDKToolCallMessage {
@@ -463,6 +465,8 @@ export interface SDKToolCallMessage {
/** Raw unparsed arguments string from the wire for consumer-side accumulation. */
rawArguments?: string;
uuid: string;
/** Run ID from the Letta API for this event (used for stale-run detection). */
runId?: string;
}
export interface SDKToolResultMessage {
@@ -471,12 +475,16 @@ export interface SDKToolResultMessage {
content: string;
isError: boolean;
uuid: string;
/** Run ID from the Letta API for this event (used for stale-run detection). */
runId?: string;
}
/** SDK stream message carrying reasoning text. */
export interface SDKReasoningMessage {
/** Discriminant identifying this message as a reasoning message. */
type: "reasoning";
/** The reasoning text. */
content: string;
/** Identifier of this message. */
uuid: string;
/** Run ID from the Letta API for this event (used for stale-run detection). */
runId?: string;
}
export interface SDKResultMessage {
@@ -488,6 +496,8 @@ export interface SDKResultMessage {
durationMs: number;
totalCostUsd?: number;
conversationId: string | null;
/** Run IDs associated with this turn (if provided by the CLI). */
runIds?: string[];
}
export interface SDKStreamEventDeltaPayload {