diff --git a/src/cli/helpers/initCommand.ts b/src/cli/helpers/initCommand.ts index f4d5147..7559889 100644 --- a/src/cli/helpers/initCommand.ts +++ b/src/cli/helpers/initCommand.ts @@ -165,8 +165,7 @@ function gatherDirListing(): string { const lines: string[] = []; const sorted = [...dirs, ...files]; - for (let i = 0; i < sorted.length; i++) { - const entry = sorted[i]!; + for (const [i, entry] of sorted.entries()) { const isLast = i === sorted.length - 1; const prefix = isLast ? "└── " : "├── "; @@ -189,8 +188,7 @@ function gatherDirListing(): string { return a.name.localeCompare(b.name); }); const childPrefix = isLast ? " " : "│ "; - for (let j = 0; j < children.length; j++) { - const child = children[j]!; + for (const [j, child] of children.entries()) { const childIsLast = j === children.length - 1; const connector = childIsLast ? "└── " : "├── "; const suffix = child.isDirectory() ? "/" : ""; diff --git a/src/tests/websocket/listen-client-image-normalization.test.ts b/src/tests/websocket/listen-client-image-normalization.test.ts new file mode 100644 index 0000000..caf20a5 --- /dev/null +++ b/src/tests/websocket/listen-client-image-normalization.test.ts @@ -0,0 +1,57 @@ +import { describe, expect, test } from "bun:test"; +import { __listenClientTestUtils } from "../../websocket/listen-client"; + +describe("listen-client inbound image normalization", () => { + test("normalizes base64 image content through the shared resize path", async () => { + const resize = async (_buffer: Buffer, mediaType: string) => ({ + data: "resized-base64-image", + mediaType: mediaType === "image/png" ? "image/jpeg" : mediaType, + width: 1600, + height: 1200, + resized: true, + }); + + const normalized = await __listenClientTestUtils.normalizeInboundMessages( + [ + { + type: "message", + role: "user", + content: [ + { type: "text", text: "describe this" }, + { + type: "image", + source: { + type: "base64", + media_type: "image/png", + data: "raw-base64-image", + }, + }, + ], + client_message_id: "cm-image-1", + }, + ], + resize, + ); + + expect(normalized).toHaveLength(1); + const message = normalized[0]; + if (!message) { + throw new Error("Expected normalized message"); + } + expect("content" in message).toBe(true); + if (!("content" in message) || typeof message.content === "string") { + throw new Error("Expected multimodal content"); + } + expect(message.content).toEqual([ + { type: "text", text: "describe this" }, + { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: "resized-base64-image", + }, + }, + ]); + }); +}); diff --git a/src/websocket/listen-client.ts b/src/websocket/listen-client.ts index 4f2c913..5ca9bf8 100644 --- a/src/websocket/listen-client.ts +++ b/src/websocket/listen-client.ts @@ -36,6 +36,7 @@ import { } from "../agent/turn-recovery-policy"; import { createBuffers } from "../cli/helpers/accumulator"; import { classifyApprovals } from "../cli/helpers/approvalClassification"; +import { resizeImageIfNeeded } from "../cli/helpers/imageResize"; import { generatePlanFilePath } from "../cli/helpers/planName"; import { drainStreamWithResume } from "../cli/helpers/stream"; import { INTERRUPTED_BY_USER } from "../constants"; @@ -256,6 +257,10 @@ interface RecoverPendingApprovalsMessage { conversationId?: string; } +type InboundMessagePayload = + | (MessageCreate & { client_message_id?: string }) + | ApprovalCreate; + interface StatusResponseMessage { type: "status_response"; currentMode: "default" | "acceptEdits" | "plan" | "bypassPermissions"; @@ -794,7 +799,7 @@ function loadPersistedCwdMap(): Map { try { const cachePath = getCwdCachePath(); if (!existsSync(cachePath)) return new Map(); - const raw = require("fs").readFileSync(cachePath, "utf-8") as string; + const raw = require("node:fs").readFileSync(cachePath, "utf-8") as string; const parsed = JSON.parse(raw) as Record; // Validate entries: only keep directories that still exist const map = new Map(); @@ -1071,6 +1076,105 @@ function mergeDequeuedBatchContent( }); } +function isBase64ImageContentPart(part: unknown): part is { + type: "image"; + source: { type: "base64"; media_type: string; data: string }; +} { + if (!part || typeof part !== "object") { + return false; + } + + const candidate = part as { + type?: unknown; + source?: { + type?: unknown; + media_type?: unknown; + data?: unknown; + }; + }; + + return ( + candidate.type === "image" && + !!candidate.source && + candidate.source.type === "base64" && + typeof candidate.source.media_type === "string" && + candidate.source.media_type.length > 0 && + typeof candidate.source.data === "string" && + candidate.source.data.length > 0 + ); +} + +async function normalizeMessageContentImages( + content: MessageCreate["content"], + resize: typeof resizeImageIfNeeded = resizeImageIfNeeded, +): Promise { + if (typeof content === "string") { + return content; + } + + let didChange = false; + const normalizedParts = await Promise.all( + content.map(async (part) => { + if (!isBase64ImageContentPart(part)) { + return part; + } + + const resized = await resize( + Buffer.from(part.source.data, "base64"), + part.source.media_type, + ); + if ( + resized.data !== part.source.data || + resized.mediaType !== part.source.media_type + ) { + didChange = true; + } + + return { + ...part, + source: { + ...part.source, + type: "base64" as const, + data: resized.data, + media_type: resized.mediaType, + }, + }; + }), + ); + + return didChange ? normalizedParts : content; +} + +async function normalizeInboundMessages( + messages: InboundMessagePayload[], + resize: typeof resizeImageIfNeeded = resizeImageIfNeeded, +): Promise { + let didChange = false; + + const normalizedMessages = await Promise.all( + messages.map(async (message) => { + if (!("content" in message)) { + return message; + } + + const normalizedContent = await normalizeMessageContentImages( + message.content, + resize, + ); + if (normalizedContent !== message.content) { + didChange = true; + return { + ...message, + content: normalizedContent, + }; + } + return message; + }), + ); + + return didChange ? normalizedMessages : messages; +} + function getPrimaryQueueMessageItem(items: QueueItem[]): QueueItem | null { for (const item of items) { if (item.kind === "message") { @@ -3168,6 +3272,7 @@ async function handleIncomingMessage( onStatusChange?.("processing", connectionId); } + const normalizedMessages = await normalizeInboundMessages(msg.messages); const messagesToSend: Array = []; let turnToolContextId: string | null = null; let queuedInterruptedToolCallIds: string[] = []; @@ -3183,9 +3288,9 @@ async function handleIncomingMessage( queuedInterruptedToolCallIds = consumed.interruptedToolCallIds; } - messagesToSend.push(...msg.messages); + messagesToSend.push(...normalizedMessages); - const firstMessage = msg.messages[0]; + const firstMessage = normalizedMessages[0]; const isApprovalMessage = firstMessage && "type" in firstMessage && @@ -3865,4 +3970,6 @@ export const __listenClientTestUtils = { normalizeToolReturnWireMessage, normalizeExecutionResultsForInterruptParity, shouldAttemptPostStopApprovalRecovery, + normalizeMessageContentImages, + normalizeInboundMessages, };