From ef1504bd9a5f268fadbdc4aeed5b1ca72593f5c3 Mon Sep 17 00:00:00 2001 From: Cameron Date: Tue, 10 Mar 2026 17:03:37 -0700 Subject: [PATCH] fix(discord): gate managed commands and harden attachment fetches (#549) Co-authored-by: Letta Code --- src/channels/attachments.ts | 14 +++- src/channels/discord-adapter.test.ts | 115 ++++++++++++++++++++++++++- src/channels/discord.ts | 86 ++++++++++++-------- src/channels/slack.ts | 4 +- 4 files changed, 180 insertions(+), 39 deletions(-) diff --git a/src/channels/attachments.ts b/src/channels/attachments.ts index c2d6941..02c95f3 100644 --- a/src/channels/attachments.ts +++ b/src/channels/attachments.ts @@ -5,6 +5,12 @@ import { Readable } from 'node:stream'; import { pipeline } from 'node:stream/promises'; const SAFE_NAME_RE = /[^A-Za-z0-9._-]/g; +const DEFAULT_DOWNLOAD_TIMEOUT_MS = 15000; + +type DownloadToFileOptions = { + headers?: Record; + timeoutMs?: number; +}; export function sanitizeFilename(input: string): string { const cleaned = input.replace(SAFE_NAME_RE, '_').replace(/^_+|_+$/g, ''); @@ -30,10 +36,14 @@ export function buildAttachmentPath( export async function downloadToFile( url: string, filePath: string, - headers?: Record + options: DownloadToFileOptions = {} ): Promise { + const { headers, timeoutMs = DEFAULT_DOWNLOAD_TIMEOUT_MS } = options; ensureParentDir(filePath); - const res = await fetch(url, { headers }); + const res = await fetch(url, { + headers, + signal: AbortSignal.timeout(timeoutMs), + }); if (!res.ok || !res.body) { throw new Error(`Download failed (${res.status})`); } diff --git a/src/channels/discord-adapter.test.ts b/src/channels/discord-adapter.test.ts index be34579..20d9c01 100644 --- a/src/channels/discord-adapter.test.ts +++ b/src/channels/discord-adapter.test.ts @@ -93,7 +93,7 @@ function makeMessage(params: { }; member: { displayName: string }; mentions: { has: () => boolean }; - attachments: { find: () => undefined; values: () => unknown[] }; + attachments: { find: (_predicate?: unknown) => unknown | undefined; values: () => unknown[] }; createdAt: Date; reply: ReturnType; startThread: ReturnType; @@ -120,7 +120,7 @@ function makeMessage(params: { member: { displayName: 'Alice' }, mentions: { has: () => false }, attachments: { - find: () => undefined, + find: (_predicate?: unknown) => undefined, values: () => [], }, createdAt: new Date(), @@ -131,9 +131,120 @@ function makeMessage(params: { describe('DiscordAdapter command gating', () => { afterEach(async () => { + vi.restoreAllMocks(); vi.clearAllMocks(); }); + it('does not download attachments for groups outside allowlist', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + attachmentsDir: '/tmp/attachments', + groups: { + 'channel-2': { mode: 'open' }, + }, + }); + const onMessage = vi.fn().mockResolvedValue(undefined); + adapter.onMessage = onMessage; + + const fetchSpy = vi.spyOn(globalThis, 'fetch'); + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const message = makeMessage({ + content: 'hello', + isThread: false, + channelId: 'channel-1', + }); + message.attachments = { + find: () => undefined, + values: () => [{ + id: 'att-1', + name: 'image.png', + size: 123, + url: 'https://cdn.example.com/image.png', + }], + }; + + await client!.emit('messageCreate', message); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(onMessage).not.toHaveBeenCalled(); + await adapter.stop(); + }); + + it('does not fetch voice attachment audio for groups outside allowlist', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + groups: { + 'channel-2': { mode: 'open' }, + }, + }); + const onMessage = vi.fn().mockResolvedValue(undefined); + adapter.onMessage = onMessage; + + const fetchSpy = vi.spyOn(globalThis, 'fetch'); + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const message = makeMessage({ + content: '', + isThread: false, + channelId: 'channel-1', + }); + message.attachments = { + find: () => ({ + contentType: 'audio/ogg', + name: 'voice.ogg', + url: 'https://cdn.example.com/voice.ogg', + }), + values: () => [{ + id: 'att-audio-1', + contentType: 'audio/ogg', + name: 'voice.ogg', + size: 321, + url: 'https://cdn.example.com/voice.ogg', + }], + }; + + await client!.emit('messageCreate', message); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(onMessage).not.toHaveBeenCalled(); + expect(message.reply).not.toHaveBeenCalled(); + await adapter.stop(); + }); + + it('blocks managed slash commands for groups outside allowlist', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + groups: { + 'channel-2': { mode: 'open' }, + }, + }); + const onCommand = vi.fn().mockResolvedValue('ok'); + adapter.onCommand = onCommand; + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const message = makeMessage({ + content: '/status', + isThread: false, + channelId: 'channel-1', + }); + + await client!.emit('messageCreate', message); + + expect(onCommand).not.toHaveBeenCalled(); + expect(message.channel.send).not.toHaveBeenCalled(); + await adapter.stop(); + }); + it('blocks top-level slash commands when threadMode is thread-only', async () => { const adapter = new DiscordAdapter({ token: 'token', diff --git a/src/channels/discord.ts b/src/channels/discord.ts index e067a6a..bd31d46 100644 --- a/src/channels/discord.ts +++ b/src/channels/discord.ts @@ -20,6 +20,7 @@ import { basename } from 'node:path'; import { createLogger } from '../logger.js'; const log = createLogger('Discord'); +const DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS = 15000; // Dynamic import to avoid requiring Discord deps if not used let Client: typeof import('discord.js').Client; let GatewayIntentBits: typeof import('discord.js').GatewayIntentBits; @@ -243,36 +244,6 @@ Ask the bot owner to approve with: let content = (message.content || '').trim(); const userId = message.author?.id; if (!userId) return; - - // Handle audio attachments - const audioAttachment = message.attachments.find(a => a.contentType?.startsWith('audio/')); - if (audioAttachment?.url) { - try { - const { isTranscriptionConfigured } = await import('../transcription/index.js'); - if (!isTranscriptionConfigured()) { - await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice'); - } else { - // Download audio - const response = await fetch(audioAttachment.url); - const buffer = Buffer.from(await response.arrayBuffer()); - - const { transcribeAudio } = await import('../transcription/index.js'); - const ext = audioAttachment.contentType?.split('/')[1] || 'mp3'; - const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`); - - if (result.success && result.text) { - log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`); - content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`; - } else { - log.error(`Transcription failed: ${result.error}`); - content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`; - } - } - } catch (error) { - log.error('Error transcribing audio:', error); - content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`; - } - } // Bypass pairing for guild (group) messages if (!message.guildId) { @@ -306,9 +277,6 @@ Ask the bot owner to approve with: } } - const attachments = await this.collectAttachments(message.attachments, message.channel.id); - if (!content && attachments.length === 0) return; - if (content.startsWith('/')) { const parts = content.slice(1).split(/\s+/); const command = parts[0]?.toLowerCase(); @@ -324,6 +292,23 @@ Ask the bot owner to approve with: // Unknown commands (or managed commands without onCommand) fall through to agent processing. if (isHelpCommand || (isManagedCommand && this.onCommand)) { + if (isGroup && this.config.groups && !isHelpCommand) { + if (!isGroupAllowed(this.config.groups, keys)) { + log.info(`Group ${chatId} not in allowlist, ignoring command`); + return; + } + if (!isGroupUserAllowed(this.config.groups, keys, userId)) { + return; + } + const mode = resolveGroupMode(this.config.groups, keys, 'open'); + if (mode === 'disabled') { + return; + } + if (mode === 'mention-only' && !wasMentioned) { + return; + } + } + let commandChatId = message.channel.id; let commandSendTarget: { send: (content: string) => Promise } | null = message.channel.isTextBased() && 'send' in message.channel @@ -434,6 +419,37 @@ Ask the bot owner to approve with: } } + const audioAttachment = message.attachments.find((a) => a.contentType?.startsWith('audio/')); + if (audioAttachment?.url) { + try { + const { isTranscriptionConfigured } = await import('../transcription/index.js'); + if (!isTranscriptionConfigured()) { + await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice'); + } else { + const response = await fetch(audioAttachment.url); + const buffer = Buffer.from(await response.arrayBuffer()); + + const { transcribeAudio } = await import('../transcription/index.js'); + const ext = audioAttachment.contentType?.split('/')[1] || 'mp3'; + const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`); + + if (result.success && result.text) { + log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`); + content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`; + } else { + log.error(`Transcription failed: ${result.error}`); + content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`; + } + } + } catch (error) { + log.error('Error transcribing audio:', error); + content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`; + } + } + + const attachments = await this.collectAttachments(message.attachments, message.channel.id); + if (!content && attachments.length === 0) return; + await this.onMessage({ channel: 'discord', chatId: effectiveChatId, @@ -720,7 +736,9 @@ Ask the bot owner to approve with: } const target = buildAttachmentPath(this.attachmentsDir, 'discord', channelId, name); try { - await downloadToFile(attachment.url, target); + await downloadToFile(attachment.url, target, { + timeoutMs: DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS, + }); entry.localPath = target; log.info(`Attachment saved to ${target}`); } catch (err) { diff --git a/src/channels/slack.ts b/src/channels/slack.ts index 0ef3ad9..8fb2336 100644 --- a/src/channels/slack.ts +++ b/src/channels/slack.ts @@ -512,7 +512,9 @@ async function maybeDownloadSlackFile( } const target = buildAttachmentPath(attachmentsDir, 'slack', channelId, name); try { - await downloadToFile(url, target, { Authorization: `Bearer ${token}` }); + await downloadToFile(url, target, { + headers: { Authorization: `Bearer ${token}` }, + }); attachment.localPath = target; log.info(`Attachment saved to ${target}`); } catch (err) {