fix(discord): gate managed commands and harden attachment fetches (#549)

Co-authored-by: Letta Code <noreply@letta.com>
This commit is contained in:
Cameron
2026-03-10 17:03:37 -07:00
committed by GitHub
parent ef63efc892
commit ef1504bd9a
4 changed files with 180 additions and 39 deletions

View File

@@ -5,6 +5,12 @@ import { Readable } from 'node:stream';
import { pipeline } from 'node:stream/promises'; import { pipeline } from 'node:stream/promises';
const SAFE_NAME_RE = /[^A-Za-z0-9._-]/g; const SAFE_NAME_RE = /[^A-Za-z0-9._-]/g;
const DEFAULT_DOWNLOAD_TIMEOUT_MS = 15000;
type DownloadToFileOptions = {
headers?: Record<string, string>;
timeoutMs?: number;
};
export function sanitizeFilename(input: string): string { export function sanitizeFilename(input: string): string {
const cleaned = input.replace(SAFE_NAME_RE, '_').replace(/^_+|_+$/g, ''); const cleaned = input.replace(SAFE_NAME_RE, '_').replace(/^_+|_+$/g, '');
@@ -30,10 +36,14 @@ export function buildAttachmentPath(
export async function downloadToFile( export async function downloadToFile(
url: string, url: string,
filePath: string, filePath: string,
headers?: Record<string, string> options: DownloadToFileOptions = {}
): Promise<void> { ): Promise<void> {
const { headers, timeoutMs = DEFAULT_DOWNLOAD_TIMEOUT_MS } = options;
ensureParentDir(filePath); ensureParentDir(filePath);
const res = await fetch(url, { headers }); const res = await fetch(url, {
headers,
signal: AbortSignal.timeout(timeoutMs),
});
if (!res.ok || !res.body) { if (!res.ok || !res.body) {
throw new Error(`Download failed (${res.status})`); throw new Error(`Download failed (${res.status})`);
} }

View File

@@ -93,7 +93,7 @@ function makeMessage(params: {
}; };
member: { displayName: string }; member: { displayName: string };
mentions: { has: () => boolean }; mentions: { has: () => boolean };
attachments: { find: () => undefined; values: () => unknown[] }; attachments: { find: (_predicate?: unknown) => unknown | undefined; values: () => unknown[] };
createdAt: Date; createdAt: Date;
reply: ReturnType<typeof vi.fn>; reply: ReturnType<typeof vi.fn>;
startThread: ReturnType<typeof vi.fn>; startThread: ReturnType<typeof vi.fn>;
@@ -120,7 +120,7 @@ function makeMessage(params: {
member: { displayName: 'Alice' }, member: { displayName: 'Alice' },
mentions: { has: () => false }, mentions: { has: () => false },
attachments: { attachments: {
find: () => undefined, find: (_predicate?: unknown) => undefined,
values: () => [], values: () => [],
}, },
createdAt: new Date(), createdAt: new Date(),
@@ -131,9 +131,120 @@ function makeMessage(params: {
describe('DiscordAdapter command gating', () => { describe('DiscordAdapter command gating', () => {
afterEach(async () => { afterEach(async () => {
vi.restoreAllMocks();
vi.clearAllMocks(); vi.clearAllMocks();
}); });
it('does not download attachments for groups outside allowlist', async () => {
const adapter = new DiscordAdapter({
token: 'token',
attachmentsDir: '/tmp/attachments',
groups: {
'channel-2': { mode: 'open' },
},
});
const onMessage = vi.fn().mockResolvedValue(undefined);
adapter.onMessage = onMessage;
const fetchSpy = vi.spyOn(globalThis, 'fetch');
await adapter.start();
const client = discordMock.getLatestClient();
expect(client).toBeTruthy();
const message = makeMessage({
content: 'hello',
isThread: false,
channelId: 'channel-1',
});
message.attachments = {
find: () => undefined,
values: () => [{
id: 'att-1',
name: 'image.png',
size: 123,
url: 'https://cdn.example.com/image.png',
}],
};
await client!.emit('messageCreate', message);
expect(fetchSpy).not.toHaveBeenCalled();
expect(onMessage).not.toHaveBeenCalled();
await adapter.stop();
});
it('does not fetch voice attachment audio for groups outside allowlist', async () => {
const adapter = new DiscordAdapter({
token: 'token',
groups: {
'channel-2': { mode: 'open' },
},
});
const onMessage = vi.fn().mockResolvedValue(undefined);
adapter.onMessage = onMessage;
const fetchSpy = vi.spyOn(globalThis, 'fetch');
await adapter.start();
const client = discordMock.getLatestClient();
expect(client).toBeTruthy();
const message = makeMessage({
content: '',
isThread: false,
channelId: 'channel-1',
});
message.attachments = {
find: () => ({
contentType: 'audio/ogg',
name: 'voice.ogg',
url: 'https://cdn.example.com/voice.ogg',
}),
values: () => [{
id: 'att-audio-1',
contentType: 'audio/ogg',
name: 'voice.ogg',
size: 321,
url: 'https://cdn.example.com/voice.ogg',
}],
};
await client!.emit('messageCreate', message);
expect(fetchSpy).not.toHaveBeenCalled();
expect(onMessage).not.toHaveBeenCalled();
expect(message.reply).not.toHaveBeenCalled();
await adapter.stop();
});
it('blocks managed slash commands for groups outside allowlist', async () => {
const adapter = new DiscordAdapter({
token: 'token',
groups: {
'channel-2': { mode: 'open' },
},
});
const onCommand = vi.fn().mockResolvedValue('ok');
adapter.onCommand = onCommand;
await adapter.start();
const client = discordMock.getLatestClient();
expect(client).toBeTruthy();
const message = makeMessage({
content: '/status',
isThread: false,
channelId: 'channel-1',
});
await client!.emit('messageCreate', message);
expect(onCommand).not.toHaveBeenCalled();
expect(message.channel.send).not.toHaveBeenCalled();
await adapter.stop();
});
it('blocks top-level slash commands when threadMode is thread-only', async () => { it('blocks top-level slash commands when threadMode is thread-only', async () => {
const adapter = new DiscordAdapter({ const adapter = new DiscordAdapter({
token: 'token', token: 'token',

View File

@@ -20,6 +20,7 @@ import { basename } from 'node:path';
import { createLogger } from '../logger.js'; import { createLogger } from '../logger.js';
const log = createLogger('Discord'); const log = createLogger('Discord');
const DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS = 15000;
// Dynamic import to avoid requiring Discord deps if not used // Dynamic import to avoid requiring Discord deps if not used
let Client: typeof import('discord.js').Client; let Client: typeof import('discord.js').Client;
let GatewayIntentBits: typeof import('discord.js').GatewayIntentBits; let GatewayIntentBits: typeof import('discord.js').GatewayIntentBits;
@@ -243,36 +244,6 @@ Ask the bot owner to approve with:
let content = (message.content || '').trim(); let content = (message.content || '').trim();
const userId = message.author?.id; const userId = message.author?.id;
if (!userId) return; if (!userId) return;
// Handle audio attachments
const audioAttachment = message.attachments.find(a => a.contentType?.startsWith('audio/'));
if (audioAttachment?.url) {
try {
const { isTranscriptionConfigured } = await import('../transcription/index.js');
if (!isTranscriptionConfigured()) {
await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice');
} else {
// Download audio
const response = await fetch(audioAttachment.url);
const buffer = Buffer.from(await response.arrayBuffer());
const { transcribeAudio } = await import('../transcription/index.js');
const ext = audioAttachment.contentType?.split('/')[1] || 'mp3';
const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`);
if (result.success && result.text) {
log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`);
content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`;
} else {
log.error(`Transcription failed: ${result.error}`);
content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`;
}
}
} catch (error) {
log.error('Error transcribing audio:', error);
content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`;
}
}
// Bypass pairing for guild (group) messages // Bypass pairing for guild (group) messages
if (!message.guildId) { if (!message.guildId) {
@@ -306,9 +277,6 @@ Ask the bot owner to approve with:
} }
} }
const attachments = await this.collectAttachments(message.attachments, message.channel.id);
if (!content && attachments.length === 0) return;
if (content.startsWith('/')) { if (content.startsWith('/')) {
const parts = content.slice(1).split(/\s+/); const parts = content.slice(1).split(/\s+/);
const command = parts[0]?.toLowerCase(); const command = parts[0]?.toLowerCase();
@@ -324,6 +292,23 @@ Ask the bot owner to approve with:
// Unknown commands (or managed commands without onCommand) fall through to agent processing. // Unknown commands (or managed commands without onCommand) fall through to agent processing.
if (isHelpCommand || (isManagedCommand && this.onCommand)) { if (isHelpCommand || (isManagedCommand && this.onCommand)) {
if (isGroup && this.config.groups && !isHelpCommand) {
if (!isGroupAllowed(this.config.groups, keys)) {
log.info(`Group ${chatId} not in allowlist, ignoring command`);
return;
}
if (!isGroupUserAllowed(this.config.groups, keys, userId)) {
return;
}
const mode = resolveGroupMode(this.config.groups, keys, 'open');
if (mode === 'disabled') {
return;
}
if (mode === 'mention-only' && !wasMentioned) {
return;
}
}
let commandChatId = message.channel.id; let commandChatId = message.channel.id;
let commandSendTarget: { send: (content: string) => Promise<unknown> } | null = let commandSendTarget: { send: (content: string) => Promise<unknown> } | null =
message.channel.isTextBased() && 'send' in message.channel message.channel.isTextBased() && 'send' in message.channel
@@ -434,6 +419,37 @@ Ask the bot owner to approve with:
} }
} }
const audioAttachment = message.attachments.find((a) => a.contentType?.startsWith('audio/'));
if (audioAttachment?.url) {
try {
const { isTranscriptionConfigured } = await import('../transcription/index.js');
if (!isTranscriptionConfigured()) {
await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice');
} else {
const response = await fetch(audioAttachment.url);
const buffer = Buffer.from(await response.arrayBuffer());
const { transcribeAudio } = await import('../transcription/index.js');
const ext = audioAttachment.contentType?.split('/')[1] || 'mp3';
const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`);
if (result.success && result.text) {
log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`);
content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`;
} else {
log.error(`Transcription failed: ${result.error}`);
content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`;
}
}
} catch (error) {
log.error('Error transcribing audio:', error);
content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`;
}
}
const attachments = await this.collectAttachments(message.attachments, message.channel.id);
if (!content && attachments.length === 0) return;
await this.onMessage({ await this.onMessage({
channel: 'discord', channel: 'discord',
chatId: effectiveChatId, chatId: effectiveChatId,
@@ -720,7 +736,9 @@ Ask the bot owner to approve with:
} }
const target = buildAttachmentPath(this.attachmentsDir, 'discord', channelId, name); const target = buildAttachmentPath(this.attachmentsDir, 'discord', channelId, name);
try { try {
await downloadToFile(attachment.url, target); await downloadToFile(attachment.url, target, {
timeoutMs: DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS,
});
entry.localPath = target; entry.localPath = target;
log.info(`Attachment saved to ${target}`); log.info(`Attachment saved to ${target}`);
} catch (err) { } catch (err) {

View File

@@ -512,7 +512,9 @@ async function maybeDownloadSlackFile(
} }
const target = buildAttachmentPath(attachmentsDir, 'slack', channelId, name); const target = buildAttachmentPath(attachmentsDir, 'slack', channelId, name);
try { try {
await downloadToFile(url, target, { Authorization: `Bearer ${token}` }); await downloadToFile(url, target, {
headers: { Authorization: `Bearer ${token}` },
});
attachment.localPath = target; attachment.localPath = target;
log.info(`Attachment saved to ${target}`); log.info(`Attachment saved to ${target}`);
} catch (err) { } catch (err) {