fix(discord): gate managed commands and harden attachment fetches (#549)
Co-authored-by: Letta Code <noreply@letta.com>
This commit is contained in:
@@ -5,6 +5,12 @@ import { Readable } from 'node:stream';
|
||||
import { pipeline } from 'node:stream/promises';
|
||||
|
||||
const SAFE_NAME_RE = /[^A-Za-z0-9._-]/g;
|
||||
const DEFAULT_DOWNLOAD_TIMEOUT_MS = 15000;
|
||||
|
||||
type DownloadToFileOptions = {
|
||||
headers?: Record<string, string>;
|
||||
timeoutMs?: number;
|
||||
};
|
||||
|
||||
export function sanitizeFilename(input: string): string {
|
||||
const cleaned = input.replace(SAFE_NAME_RE, '_').replace(/^_+|_+$/g, '');
|
||||
@@ -30,10 +36,14 @@ export function buildAttachmentPath(
|
||||
export async function downloadToFile(
|
||||
url: string,
|
||||
filePath: string,
|
||||
headers?: Record<string, string>
|
||||
options: DownloadToFileOptions = {}
|
||||
): Promise<void> {
|
||||
const { headers, timeoutMs = DEFAULT_DOWNLOAD_TIMEOUT_MS } = options;
|
||||
ensureParentDir(filePath);
|
||||
const res = await fetch(url, { headers });
|
||||
const res = await fetch(url, {
|
||||
headers,
|
||||
signal: AbortSignal.timeout(timeoutMs),
|
||||
});
|
||||
if (!res.ok || !res.body) {
|
||||
throw new Error(`Download failed (${res.status})`);
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ function makeMessage(params: {
|
||||
};
|
||||
member: { displayName: string };
|
||||
mentions: { has: () => boolean };
|
||||
attachments: { find: () => undefined; values: () => unknown[] };
|
||||
attachments: { find: (_predicate?: unknown) => unknown | undefined; values: () => unknown[] };
|
||||
createdAt: Date;
|
||||
reply: ReturnType<typeof vi.fn>;
|
||||
startThread: ReturnType<typeof vi.fn>;
|
||||
@@ -120,7 +120,7 @@ function makeMessage(params: {
|
||||
member: { displayName: 'Alice' },
|
||||
mentions: { has: () => false },
|
||||
attachments: {
|
||||
find: () => undefined,
|
||||
find: (_predicate?: unknown) => undefined,
|
||||
values: () => [],
|
||||
},
|
||||
createdAt: new Date(),
|
||||
@@ -131,9 +131,120 @@ function makeMessage(params: {
|
||||
|
||||
describe('DiscordAdapter command gating', () => {
|
||||
afterEach(async () => {
|
||||
vi.restoreAllMocks();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('does not download attachments for groups outside allowlist', async () => {
|
||||
const adapter = new DiscordAdapter({
|
||||
token: 'token',
|
||||
attachmentsDir: '/tmp/attachments',
|
||||
groups: {
|
||||
'channel-2': { mode: 'open' },
|
||||
},
|
||||
});
|
||||
const onMessage = vi.fn().mockResolvedValue(undefined);
|
||||
adapter.onMessage = onMessage;
|
||||
|
||||
const fetchSpy = vi.spyOn(globalThis, 'fetch');
|
||||
|
||||
await adapter.start();
|
||||
const client = discordMock.getLatestClient();
|
||||
expect(client).toBeTruthy();
|
||||
|
||||
const message = makeMessage({
|
||||
content: 'hello',
|
||||
isThread: false,
|
||||
channelId: 'channel-1',
|
||||
});
|
||||
message.attachments = {
|
||||
find: () => undefined,
|
||||
values: () => [{
|
||||
id: 'att-1',
|
||||
name: 'image.png',
|
||||
size: 123,
|
||||
url: 'https://cdn.example.com/image.png',
|
||||
}],
|
||||
};
|
||||
|
||||
await client!.emit('messageCreate', message);
|
||||
|
||||
expect(fetchSpy).not.toHaveBeenCalled();
|
||||
expect(onMessage).not.toHaveBeenCalled();
|
||||
await adapter.stop();
|
||||
});
|
||||
|
||||
it('does not fetch voice attachment audio for groups outside allowlist', async () => {
|
||||
const adapter = new DiscordAdapter({
|
||||
token: 'token',
|
||||
groups: {
|
||||
'channel-2': { mode: 'open' },
|
||||
},
|
||||
});
|
||||
const onMessage = vi.fn().mockResolvedValue(undefined);
|
||||
adapter.onMessage = onMessage;
|
||||
|
||||
const fetchSpy = vi.spyOn(globalThis, 'fetch');
|
||||
|
||||
await adapter.start();
|
||||
const client = discordMock.getLatestClient();
|
||||
expect(client).toBeTruthy();
|
||||
|
||||
const message = makeMessage({
|
||||
content: '',
|
||||
isThread: false,
|
||||
channelId: 'channel-1',
|
||||
});
|
||||
message.attachments = {
|
||||
find: () => ({
|
||||
contentType: 'audio/ogg',
|
||||
name: 'voice.ogg',
|
||||
url: 'https://cdn.example.com/voice.ogg',
|
||||
}),
|
||||
values: () => [{
|
||||
id: 'att-audio-1',
|
||||
contentType: 'audio/ogg',
|
||||
name: 'voice.ogg',
|
||||
size: 321,
|
||||
url: 'https://cdn.example.com/voice.ogg',
|
||||
}],
|
||||
};
|
||||
|
||||
await client!.emit('messageCreate', message);
|
||||
|
||||
expect(fetchSpy).not.toHaveBeenCalled();
|
||||
expect(onMessage).not.toHaveBeenCalled();
|
||||
expect(message.reply).not.toHaveBeenCalled();
|
||||
await adapter.stop();
|
||||
});
|
||||
|
||||
it('blocks managed slash commands for groups outside allowlist', async () => {
|
||||
const adapter = new DiscordAdapter({
|
||||
token: 'token',
|
||||
groups: {
|
||||
'channel-2': { mode: 'open' },
|
||||
},
|
||||
});
|
||||
const onCommand = vi.fn().mockResolvedValue('ok');
|
||||
adapter.onCommand = onCommand;
|
||||
|
||||
await adapter.start();
|
||||
const client = discordMock.getLatestClient();
|
||||
expect(client).toBeTruthy();
|
||||
|
||||
const message = makeMessage({
|
||||
content: '/status',
|
||||
isThread: false,
|
||||
channelId: 'channel-1',
|
||||
});
|
||||
|
||||
await client!.emit('messageCreate', message);
|
||||
|
||||
expect(onCommand).not.toHaveBeenCalled();
|
||||
expect(message.channel.send).not.toHaveBeenCalled();
|
||||
await adapter.stop();
|
||||
});
|
||||
|
||||
it('blocks top-level slash commands when threadMode is thread-only', async () => {
|
||||
const adapter = new DiscordAdapter({
|
||||
token: 'token',
|
||||
|
||||
@@ -20,6 +20,7 @@ import { basename } from 'node:path';
|
||||
import { createLogger } from '../logger.js';
|
||||
|
||||
const log = createLogger('Discord');
|
||||
const DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS = 15000;
|
||||
// Dynamic import to avoid requiring Discord deps if not used
|
||||
let Client: typeof import('discord.js').Client;
|
||||
let GatewayIntentBits: typeof import('discord.js').GatewayIntentBits;
|
||||
@@ -244,36 +245,6 @@ Ask the bot owner to approve with:
|
||||
const userId = message.author?.id;
|
||||
if (!userId) return;
|
||||
|
||||
// Handle audio attachments
|
||||
const audioAttachment = message.attachments.find(a => a.contentType?.startsWith('audio/'));
|
||||
if (audioAttachment?.url) {
|
||||
try {
|
||||
const { isTranscriptionConfigured } = await import('../transcription/index.js');
|
||||
if (!isTranscriptionConfigured()) {
|
||||
await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice');
|
||||
} else {
|
||||
// Download audio
|
||||
const response = await fetch(audioAttachment.url);
|
||||
const buffer = Buffer.from(await response.arrayBuffer());
|
||||
|
||||
const { transcribeAudio } = await import('../transcription/index.js');
|
||||
const ext = audioAttachment.contentType?.split('/')[1] || 'mp3';
|
||||
const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`);
|
||||
|
||||
if (result.success && result.text) {
|
||||
log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`);
|
||||
content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`;
|
||||
} else {
|
||||
log.error(`Transcription failed: ${result.error}`);
|
||||
content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('Error transcribing audio:', error);
|
||||
content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`;
|
||||
}
|
||||
}
|
||||
|
||||
// Bypass pairing for guild (group) messages
|
||||
if (!message.guildId) {
|
||||
const access = await this.checkAccess(userId);
|
||||
@@ -306,9 +277,6 @@ Ask the bot owner to approve with:
|
||||
}
|
||||
}
|
||||
|
||||
const attachments = await this.collectAttachments(message.attachments, message.channel.id);
|
||||
if (!content && attachments.length === 0) return;
|
||||
|
||||
if (content.startsWith('/')) {
|
||||
const parts = content.slice(1).split(/\s+/);
|
||||
const command = parts[0]?.toLowerCase();
|
||||
@@ -324,6 +292,23 @@ Ask the bot owner to approve with:
|
||||
|
||||
// Unknown commands (or managed commands without onCommand) fall through to agent processing.
|
||||
if (isHelpCommand || (isManagedCommand && this.onCommand)) {
|
||||
if (isGroup && this.config.groups && !isHelpCommand) {
|
||||
if (!isGroupAllowed(this.config.groups, keys)) {
|
||||
log.info(`Group ${chatId} not in allowlist, ignoring command`);
|
||||
return;
|
||||
}
|
||||
if (!isGroupUserAllowed(this.config.groups, keys, userId)) {
|
||||
return;
|
||||
}
|
||||
const mode = resolveGroupMode(this.config.groups, keys, 'open');
|
||||
if (mode === 'disabled') {
|
||||
return;
|
||||
}
|
||||
if (mode === 'mention-only' && !wasMentioned) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let commandChatId = message.channel.id;
|
||||
let commandSendTarget: { send: (content: string) => Promise<unknown> } | null =
|
||||
message.channel.isTextBased() && 'send' in message.channel
|
||||
@@ -434,6 +419,37 @@ Ask the bot owner to approve with:
|
||||
}
|
||||
}
|
||||
|
||||
const audioAttachment = message.attachments.find((a) => a.contentType?.startsWith('audio/'));
|
||||
if (audioAttachment?.url) {
|
||||
try {
|
||||
const { isTranscriptionConfigured } = await import('../transcription/index.js');
|
||||
if (!isTranscriptionConfigured()) {
|
||||
await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice');
|
||||
} else {
|
||||
const response = await fetch(audioAttachment.url);
|
||||
const buffer = Buffer.from(await response.arrayBuffer());
|
||||
|
||||
const { transcribeAudio } = await import('../transcription/index.js');
|
||||
const ext = audioAttachment.contentType?.split('/')[1] || 'mp3';
|
||||
const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`);
|
||||
|
||||
if (result.success && result.text) {
|
||||
log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`);
|
||||
content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`;
|
||||
} else {
|
||||
log.error(`Transcription failed: ${result.error}`);
|
||||
content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('Error transcribing audio:', error);
|
||||
content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`;
|
||||
}
|
||||
}
|
||||
|
||||
const attachments = await this.collectAttachments(message.attachments, message.channel.id);
|
||||
if (!content && attachments.length === 0) return;
|
||||
|
||||
await this.onMessage({
|
||||
channel: 'discord',
|
||||
chatId: effectiveChatId,
|
||||
@@ -720,7 +736,9 @@ Ask the bot owner to approve with:
|
||||
}
|
||||
const target = buildAttachmentPath(this.attachmentsDir, 'discord', channelId, name);
|
||||
try {
|
||||
await downloadToFile(attachment.url, target);
|
||||
await downloadToFile(attachment.url, target, {
|
||||
timeoutMs: DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS,
|
||||
});
|
||||
entry.localPath = target;
|
||||
log.info(`Attachment saved to ${target}`);
|
||||
} catch (err) {
|
||||
|
||||
@@ -512,7 +512,9 @@ async function maybeDownloadSlackFile(
|
||||
}
|
||||
const target = buildAttachmentPath(attachmentsDir, 'slack', channelId, name);
|
||||
try {
|
||||
await downloadToFile(url, target, { Authorization: `Bearer ${token}` });
|
||||
await downloadToFile(url, target, {
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
});
|
||||
attachment.localPath = target;
|
||||
log.info(`Attachment saved to ${target}`);
|
||||
} catch (err) {
|
||||
|
||||
Reference in New Issue
Block a user