fix(listen): normalize inbound user image payloads (#1374)

This commit is contained in:
Charles Packer
2026-03-12 20:37:50 -07:00
committed by GitHub
parent 94ff9f6796
commit e2e82866ad
3 changed files with 169 additions and 7 deletions

View File

@@ -165,8 +165,7 @@ function gatherDirListing(): string {
const lines: string[] = [];
const sorted = [...dirs, ...files];
for (let i = 0; i < sorted.length; i++) {
const entry = sorted[i]!;
for (const [i, entry] of sorted.entries()) {
const isLast = i === sorted.length - 1;
const prefix = isLast ? "└── " : "├── ";
@@ -189,8 +188,7 @@ function gatherDirListing(): string {
return a.name.localeCompare(b.name);
});
const childPrefix = isLast ? " " : "│ ";
for (let j = 0; j < children.length; j++) {
const child = children[j]!;
for (const [j, child] of children.entries()) {
const childIsLast = j === children.length - 1;
const connector = childIsLast ? "└── " : "├── ";
const suffix = child.isDirectory() ? "/" : "";

View File

@@ -0,0 +1,57 @@
import { describe, expect, test } from "bun:test";
import { __listenClientTestUtils } from "../../websocket/listen-client";
// Exercises normalizeInboundMessages with an injected resize stub so the test
// checks only how resize output is written back into the message content.
describe("listen-client inbound image normalization", () => {
  test("normalizes base64 image content through the shared resize path", async () => {
    // Stub resizer: always returns fixed data and reports PNG input as JPEG.
    async function stubResize(_input: Buffer, mediaType: string) {
      const outputMediaType =
        mediaType === "image/png" ? "image/jpeg" : mediaType;
      return {
        data: "resized-base64-image",
        mediaType: outputMediaType,
        width: 1600,
        height: 1200,
        resized: true,
      };
    }

    const result = await __listenClientTestUtils.normalizeInboundMessages(
      [
        {
          type: "message",
          role: "user",
          content: [
            { type: "text", text: "describe this" },
            {
              type: "image",
              source: {
                type: "base64",
                media_type: "image/png",
                data: "raw-base64-image",
              },
            },
          ],
          client_message_id: "cm-image-1",
        },
      ],
      stubResize,
    );

    expect(result).toHaveLength(1);

    const [first] = result;
    if (!first) {
      throw new Error("Expected normalized message");
    }

    expect("content" in first).toBe(true);
    if (!("content" in first) || typeof first.content === "string") {
      throw new Error("Expected multimodal content");
    }

    // The text part is untouched; the image part carries the stub's output.
    const expectedContent = [
      { type: "text", text: "describe this" },
      {
        type: "image",
        source: {
          type: "base64",
          media_type: "image/jpeg",
          data: "resized-base64-image",
        },
      },
    ];
    expect(first.content).toEqual(expectedContent);
  });
});

View File

@@ -36,6 +36,7 @@ import {
} from "../agent/turn-recovery-policy";
import { createBuffers } from "../cli/helpers/accumulator";
import { classifyApprovals } from "../cli/helpers/approvalClassification";
import { resizeImageIfNeeded } from "../cli/helpers/imageResize";
import { generatePlanFilePath } from "../cli/helpers/planName";
import { drainStreamWithResume } from "../cli/helpers/stream";
import { INTERRUPTED_BY_USER } from "../constants";
@@ -256,6 +257,10 @@ interface RecoverPendingApprovalsMessage {
conversationId?: string;
}
/**
 * Element type of the message list accepted by normalizeInboundMessages:
 * either a `MessageCreate` (optionally carrying a `client_message_id`)
 * or an `ApprovalCreate`.
 */
type InboundMessagePayload =
| (MessageCreate & { client_message_id?: string })
| ApprovalCreate;
interface StatusResponseMessage {
type: "status_response";
currentMode: "default" | "acceptEdits" | "plan" | "bypassPermissions";
@@ -794,7 +799,7 @@ function loadPersistedCwdMap(): Map<string, string> {
try {
const cachePath = getCwdCachePath();
if (!existsSync(cachePath)) return new Map();
const raw = require("fs").readFileSync(cachePath, "utf-8") as string;
const raw = require("node:fs").readFileSync(cachePath, "utf-8") as string;
const parsed = JSON.parse(raw) as Record<string, string>;
// Validate entries: only keep directories that still exist
const map = new Map<string, string>();
@@ -1071,6 +1076,105 @@ function mergeDequeuedBatchContent(
});
}
/**
 * Type guard for a content part that embeds an image inline as base64.
 *
 * Accepts only non-null objects whose `type` is "image" and whose `source`
 * has `type` "base64" together with non-empty string `media_type` and `data`.
 *
 * @param part - Arbitrary value to inspect.
 * @returns `true` when `part` has the base64 image shape described above.
 */
function isBase64ImageContentPart(part: unknown): part is {
  type: "image";
  source: { type: "base64"; media_type: string; data: string };
} {
  const isNonEmptyString = (value: unknown): value is string =>
    typeof value === "string" && value.length > 0;

  if (typeof part !== "object" || part === null) {
    return false;
  }

  const shape = part as {
    type?: unknown;
    source?: {
      type?: unknown;
      media_type?: unknown;
      data?: unknown;
    };
  };

  if (shape.type !== "image") {
    return false;
  }

  const source = shape.source;
  if (!source || source.type !== "base64") {
    return false;
  }

  return isNonEmptyString(source.media_type) && isNonEmptyString(source.data);
}
/**
 * Runs every inline base64 image part of a message's content through `resize`.
 *
 * String content is returned as-is. For array content, each part accepted by
 * isBase64ImageContentPart is decoded from base64 into a Buffer and handed to
 * `resize` together with its media type; the part's `source.data` and
 * `source.media_type` are then replaced with the resize result. All parts are
 * processed concurrently via Promise.all.
 *
 * @param content - Message content, either a plain string or an array of parts.
 * @param resize - Resizer to apply; defaults to resizeImageIfNeeded.
 * @returns The original `content` reference when no image's data or media
 *   type differs after resizing, otherwise a new array of parts.
 */
async function normalizeMessageContentImages(
  content: MessageCreate["content"],
  resize: typeof resizeImageIfNeeded = resizeImageIfNeeded,
): Promise<MessageCreate["content"]> {
  if (typeof content === "string") {
    return content;
  }

  const outcomes = await Promise.all(
    content.map(async (part) => {
      if (!isBase64ImageContentPart(part)) {
        return { part, changed: false };
      }

      const { source } = part;
      const { data, mediaType } = await resize(
        Buffer.from(source.data, "base64"),
        source.media_type,
      );

      return {
        part: {
          ...part,
          source: {
            ...source,
            type: "base64" as const,
            data,
            media_type: mediaType,
          },
        },
        // A part counts as changed only if the payload or its type differs.
        changed: data !== source.data || mediaType !== source.media_type,
      };
    }),
  );

  // Keep the caller's array identity when nothing was actually rewritten.
  const anyChanged = outcomes.some((outcome) => outcome.changed);
  return anyChanged ? outcomes.map((outcome) => outcome.part) : content;
}
/**
 * Applies normalizeMessageContentImages to every inbound message that has a
 * `content` property; messages without one are passed through untouched.
 * Messages are processed concurrently via Promise.all.
 *
 * @param messages - Inbound payloads to normalize.
 * @param resize - Resizer forwarded to normalizeMessageContentImages;
 *   defaults to resizeImageIfNeeded.
 * @returns The original `messages` array when no message's content was
 *   replaced, otherwise a new array in which each affected message is a
 *   shallow copy carrying the normalized content.
 */
async function normalizeInboundMessages(
  messages: InboundMessagePayload[],
  resize: typeof resizeImageIfNeeded = resizeImageIfNeeded,
): Promise<InboundMessagePayload[]> {
  const rewritten = await Promise.all(
    messages.map(async (original) => {
      if (!("content" in original)) {
        return original;
      }

      const content = await normalizeMessageContentImages(
        original.content,
        resize,
      );

      // Identity comparison: the helper returns the same reference when
      // nothing changed, so only real rewrites produce a new message object.
      return content === original.content
        ? original
        : { ...original, content };
    }),
  );

  // A slot differs from the input only when its message was copied above.
  const anyReplaced = rewritten.some(
    (candidate, index) => candidate !== messages[index],
  );
  return anyReplaced ? rewritten : messages;
}
function getPrimaryQueueMessageItem(items: QueueItem[]): QueueItem | null {
for (const item of items) {
if (item.kind === "message") {
@@ -3168,6 +3272,7 @@ async function handleIncomingMessage(
onStatusChange?.("processing", connectionId);
}
const normalizedMessages = await normalizeInboundMessages(msg.messages);
const messagesToSend: Array<MessageCreate | ApprovalCreate> = [];
let turnToolContextId: string | null = null;
let queuedInterruptedToolCallIds: string[] = [];
@@ -3183,9 +3288,9 @@ async function handleIncomingMessage(
queuedInterruptedToolCallIds = consumed.interruptedToolCallIds;
}
messagesToSend.push(...msg.messages);
messagesToSend.push(...normalizedMessages);
const firstMessage = msg.messages[0];
const firstMessage = normalizedMessages[0];
const isApprovalMessage =
firstMessage &&
"type" in firstMessage &&
@@ -3865,4 +3970,6 @@ export const __listenClientTestUtils = {
normalizeToolReturnWireMessage,
normalizeExecutionResultsForInterruptParity,
shouldAttemptPostStopApprovalRecovery,
normalizeMessageContentImages,
normalizeInboundMessages,
};