fix(discord): gate managed commands and harden attachment fetches (#549)
Co-authored-by: Letta Code <noreply@letta.com>
This commit is contained in:
@@ -5,6 +5,12 @@ import { Readable } from 'node:stream';
|
|||||||
import { pipeline } from 'node:stream/promises';
|
import { pipeline } from 'node:stream/promises';
|
||||||
|
|
||||||
const SAFE_NAME_RE = /[^A-Za-z0-9._-]/g;
|
const SAFE_NAME_RE = /[^A-Za-z0-9._-]/g;
|
||||||
|
const DEFAULT_DOWNLOAD_TIMEOUT_MS = 15000;
|
||||||
|
|
||||||
|
type DownloadToFileOptions = {
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
timeoutMs?: number;
|
||||||
|
};
|
||||||
|
|
||||||
export function sanitizeFilename(input: string): string {
|
export function sanitizeFilename(input: string): string {
|
||||||
const cleaned = input.replace(SAFE_NAME_RE, '_').replace(/^_+|_+$/g, '');
|
const cleaned = input.replace(SAFE_NAME_RE, '_').replace(/^_+|_+$/g, '');
|
||||||
@@ -30,10 +36,14 @@ export function buildAttachmentPath(
|
|||||||
export async function downloadToFile(
|
export async function downloadToFile(
|
||||||
url: string,
|
url: string,
|
||||||
filePath: string,
|
filePath: string,
|
||||||
headers?: Record<string, string>
|
options: DownloadToFileOptions = {}
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
|
const { headers, timeoutMs = DEFAULT_DOWNLOAD_TIMEOUT_MS } = options;
|
||||||
ensureParentDir(filePath);
|
ensureParentDir(filePath);
|
||||||
const res = await fetch(url, { headers });
|
const res = await fetch(url, {
|
||||||
|
headers,
|
||||||
|
signal: AbortSignal.timeout(timeoutMs),
|
||||||
|
});
|
||||||
if (!res.ok || !res.body) {
|
if (!res.ok || !res.body) {
|
||||||
throw new Error(`Download failed (${res.status})`);
|
throw new Error(`Download failed (${res.status})`);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ function makeMessage(params: {
|
|||||||
};
|
};
|
||||||
member: { displayName: string };
|
member: { displayName: string };
|
||||||
mentions: { has: () => boolean };
|
mentions: { has: () => boolean };
|
||||||
attachments: { find: () => undefined; values: () => unknown[] };
|
attachments: { find: (_predicate?: unknown) => unknown | undefined; values: () => unknown[] };
|
||||||
createdAt: Date;
|
createdAt: Date;
|
||||||
reply: ReturnType<typeof vi.fn>;
|
reply: ReturnType<typeof vi.fn>;
|
||||||
startThread: ReturnType<typeof vi.fn>;
|
startThread: ReturnType<typeof vi.fn>;
|
||||||
@@ -120,7 +120,7 @@ function makeMessage(params: {
|
|||||||
member: { displayName: 'Alice' },
|
member: { displayName: 'Alice' },
|
||||||
mentions: { has: () => false },
|
mentions: { has: () => false },
|
||||||
attachments: {
|
attachments: {
|
||||||
find: () => undefined,
|
find: (_predicate?: unknown) => undefined,
|
||||||
values: () => [],
|
values: () => [],
|
||||||
},
|
},
|
||||||
createdAt: new Date(),
|
createdAt: new Date(),
|
||||||
@@ -131,9 +131,120 @@ function makeMessage(params: {
|
|||||||
|
|
||||||
describe('DiscordAdapter command gating', () => {
|
describe('DiscordAdapter command gating', () => {
|
||||||
afterEach(async () => {
|
afterEach(async () => {
|
||||||
|
vi.restoreAllMocks();
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('does not download attachments for groups outside allowlist', async () => {
|
||||||
|
const adapter = new DiscordAdapter({
|
||||||
|
token: 'token',
|
||||||
|
attachmentsDir: '/tmp/attachments',
|
||||||
|
groups: {
|
||||||
|
'channel-2': { mode: 'open' },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const onMessage = vi.fn().mockResolvedValue(undefined);
|
||||||
|
adapter.onMessage = onMessage;
|
||||||
|
|
||||||
|
const fetchSpy = vi.spyOn(globalThis, 'fetch');
|
||||||
|
|
||||||
|
await adapter.start();
|
||||||
|
const client = discordMock.getLatestClient();
|
||||||
|
expect(client).toBeTruthy();
|
||||||
|
|
||||||
|
const message = makeMessage({
|
||||||
|
content: 'hello',
|
||||||
|
isThread: false,
|
||||||
|
channelId: 'channel-1',
|
||||||
|
});
|
||||||
|
message.attachments = {
|
||||||
|
find: () => undefined,
|
||||||
|
values: () => [{
|
||||||
|
id: 'att-1',
|
||||||
|
name: 'image.png',
|
||||||
|
size: 123,
|
||||||
|
url: 'https://cdn.example.com/image.png',
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
await client!.emit('messageCreate', message);
|
||||||
|
|
||||||
|
expect(fetchSpy).not.toHaveBeenCalled();
|
||||||
|
expect(onMessage).not.toHaveBeenCalled();
|
||||||
|
await adapter.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not fetch voice attachment audio for groups outside allowlist', async () => {
|
||||||
|
const adapter = new DiscordAdapter({
|
||||||
|
token: 'token',
|
||||||
|
groups: {
|
||||||
|
'channel-2': { mode: 'open' },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const onMessage = vi.fn().mockResolvedValue(undefined);
|
||||||
|
adapter.onMessage = onMessage;
|
||||||
|
|
||||||
|
const fetchSpy = vi.spyOn(globalThis, 'fetch');
|
||||||
|
|
||||||
|
await adapter.start();
|
||||||
|
const client = discordMock.getLatestClient();
|
||||||
|
expect(client).toBeTruthy();
|
||||||
|
|
||||||
|
const message = makeMessage({
|
||||||
|
content: '',
|
||||||
|
isThread: false,
|
||||||
|
channelId: 'channel-1',
|
||||||
|
});
|
||||||
|
message.attachments = {
|
||||||
|
find: () => ({
|
||||||
|
contentType: 'audio/ogg',
|
||||||
|
name: 'voice.ogg',
|
||||||
|
url: 'https://cdn.example.com/voice.ogg',
|
||||||
|
}),
|
||||||
|
values: () => [{
|
||||||
|
id: 'att-audio-1',
|
||||||
|
contentType: 'audio/ogg',
|
||||||
|
name: 'voice.ogg',
|
||||||
|
size: 321,
|
||||||
|
url: 'https://cdn.example.com/voice.ogg',
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
await client!.emit('messageCreate', message);
|
||||||
|
|
||||||
|
expect(fetchSpy).not.toHaveBeenCalled();
|
||||||
|
expect(onMessage).not.toHaveBeenCalled();
|
||||||
|
expect(message.reply).not.toHaveBeenCalled();
|
||||||
|
await adapter.stop();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('blocks managed slash commands for groups outside allowlist', async () => {
|
||||||
|
const adapter = new DiscordAdapter({
|
||||||
|
token: 'token',
|
||||||
|
groups: {
|
||||||
|
'channel-2': { mode: 'open' },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const onCommand = vi.fn().mockResolvedValue('ok');
|
||||||
|
adapter.onCommand = onCommand;
|
||||||
|
|
||||||
|
await adapter.start();
|
||||||
|
const client = discordMock.getLatestClient();
|
||||||
|
expect(client).toBeTruthy();
|
||||||
|
|
||||||
|
const message = makeMessage({
|
||||||
|
content: '/status',
|
||||||
|
isThread: false,
|
||||||
|
channelId: 'channel-1',
|
||||||
|
});
|
||||||
|
|
||||||
|
await client!.emit('messageCreate', message);
|
||||||
|
|
||||||
|
expect(onCommand).not.toHaveBeenCalled();
|
||||||
|
expect(message.channel.send).not.toHaveBeenCalled();
|
||||||
|
await adapter.stop();
|
||||||
|
});
|
||||||
|
|
||||||
it('blocks top-level slash commands when threadMode is thread-only', async () => {
|
it('blocks top-level slash commands when threadMode is thread-only', async () => {
|
||||||
const adapter = new DiscordAdapter({
|
const adapter = new DiscordAdapter({
|
||||||
token: 'token',
|
token: 'token',
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import { basename } from 'node:path';
|
|||||||
import { createLogger } from '../logger.js';
|
import { createLogger } from '../logger.js';
|
||||||
|
|
||||||
const log = createLogger('Discord');
|
const log = createLogger('Discord');
|
||||||
|
const DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS = 15000;
|
||||||
// Dynamic import to avoid requiring Discord deps if not used
|
// Dynamic import to avoid requiring Discord deps if not used
|
||||||
let Client: typeof import('discord.js').Client;
|
let Client: typeof import('discord.js').Client;
|
||||||
let GatewayIntentBits: typeof import('discord.js').GatewayIntentBits;
|
let GatewayIntentBits: typeof import('discord.js').GatewayIntentBits;
|
||||||
@@ -243,36 +244,6 @@ Ask the bot owner to approve with:
|
|||||||
let content = (message.content || '').trim();
|
let content = (message.content || '').trim();
|
||||||
const userId = message.author?.id;
|
const userId = message.author?.id;
|
||||||
if (!userId) return;
|
if (!userId) return;
|
||||||
|
|
||||||
// Handle audio attachments
|
|
||||||
const audioAttachment = message.attachments.find(a => a.contentType?.startsWith('audio/'));
|
|
||||||
if (audioAttachment?.url) {
|
|
||||||
try {
|
|
||||||
const { isTranscriptionConfigured } = await import('../transcription/index.js');
|
|
||||||
if (!isTranscriptionConfigured()) {
|
|
||||||
await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice');
|
|
||||||
} else {
|
|
||||||
// Download audio
|
|
||||||
const response = await fetch(audioAttachment.url);
|
|
||||||
const buffer = Buffer.from(await response.arrayBuffer());
|
|
||||||
|
|
||||||
const { transcribeAudio } = await import('../transcription/index.js');
|
|
||||||
const ext = audioAttachment.contentType?.split('/')[1] || 'mp3';
|
|
||||||
const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`);
|
|
||||||
|
|
||||||
if (result.success && result.text) {
|
|
||||||
log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`);
|
|
||||||
content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`;
|
|
||||||
} else {
|
|
||||||
log.error(`Transcription failed: ${result.error}`);
|
|
||||||
content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
log.error('Error transcribing audio:', error);
|
|
||||||
content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bypass pairing for guild (group) messages
|
// Bypass pairing for guild (group) messages
|
||||||
if (!message.guildId) {
|
if (!message.guildId) {
|
||||||
@@ -306,9 +277,6 @@ Ask the bot owner to approve with:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const attachments = await this.collectAttachments(message.attachments, message.channel.id);
|
|
||||||
if (!content && attachments.length === 0) return;
|
|
||||||
|
|
||||||
if (content.startsWith('/')) {
|
if (content.startsWith('/')) {
|
||||||
const parts = content.slice(1).split(/\s+/);
|
const parts = content.slice(1).split(/\s+/);
|
||||||
const command = parts[0]?.toLowerCase();
|
const command = parts[0]?.toLowerCase();
|
||||||
@@ -324,6 +292,23 @@ Ask the bot owner to approve with:
|
|||||||
|
|
||||||
// Unknown commands (or managed commands without onCommand) fall through to agent processing.
|
// Unknown commands (or managed commands without onCommand) fall through to agent processing.
|
||||||
if (isHelpCommand || (isManagedCommand && this.onCommand)) {
|
if (isHelpCommand || (isManagedCommand && this.onCommand)) {
|
||||||
|
if (isGroup && this.config.groups && !isHelpCommand) {
|
||||||
|
if (!isGroupAllowed(this.config.groups, keys)) {
|
||||||
|
log.info(`Group ${chatId} not in allowlist, ignoring command`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!isGroupUserAllowed(this.config.groups, keys, userId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const mode = resolveGroupMode(this.config.groups, keys, 'open');
|
||||||
|
if (mode === 'disabled') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (mode === 'mention-only' && !wasMentioned) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let commandChatId = message.channel.id;
|
let commandChatId = message.channel.id;
|
||||||
let commandSendTarget: { send: (content: string) => Promise<unknown> } | null =
|
let commandSendTarget: { send: (content: string) => Promise<unknown> } | null =
|
||||||
message.channel.isTextBased() && 'send' in message.channel
|
message.channel.isTextBased() && 'send' in message.channel
|
||||||
@@ -434,6 +419,37 @@ Ask the bot owner to approve with:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const audioAttachment = message.attachments.find((a) => a.contentType?.startsWith('audio/'));
|
||||||
|
if (audioAttachment?.url) {
|
||||||
|
try {
|
||||||
|
const { isTranscriptionConfigured } = await import('../transcription/index.js');
|
||||||
|
if (!isTranscriptionConfigured()) {
|
||||||
|
await message.reply('Voice messages require a transcription API key. See: https://github.com/letta-ai/lettabot#voice');
|
||||||
|
} else {
|
||||||
|
const response = await fetch(audioAttachment.url);
|
||||||
|
const buffer = Buffer.from(await response.arrayBuffer());
|
||||||
|
|
||||||
|
const { transcribeAudio } = await import('../transcription/index.js');
|
||||||
|
const ext = audioAttachment.contentType?.split('/')[1] || 'mp3';
|
||||||
|
const result = await transcribeAudio(buffer, audioAttachment.name || `audio.${ext}`);
|
||||||
|
|
||||||
|
if (result.success && result.text) {
|
||||||
|
log.info(`Transcribed audio: "${result.text.slice(0, 50)}..."`);
|
||||||
|
content = (content ? content + '\n' : '') + `[Voice message]: ${result.text}`;
|
||||||
|
} else {
|
||||||
|
log.error(`Transcription failed: ${result.error}`);
|
||||||
|
content = (content ? content + '\n' : '') + `[Voice message - transcription failed: ${result.error}]`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
log.error('Error transcribing audio:', error);
|
||||||
|
content = (content ? content + '\n' : '') + `[Voice message - error: ${error instanceof Error ? error.message : 'unknown error'}]`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const attachments = await this.collectAttachments(message.attachments, message.channel.id);
|
||||||
|
if (!content && attachments.length === 0) return;
|
||||||
|
|
||||||
await this.onMessage({
|
await this.onMessage({
|
||||||
channel: 'discord',
|
channel: 'discord',
|
||||||
chatId: effectiveChatId,
|
chatId: effectiveChatId,
|
||||||
@@ -720,7 +736,9 @@ Ask the bot owner to approve with:
|
|||||||
}
|
}
|
||||||
const target = buildAttachmentPath(this.attachmentsDir, 'discord', channelId, name);
|
const target = buildAttachmentPath(this.attachmentsDir, 'discord', channelId, name);
|
||||||
try {
|
try {
|
||||||
await downloadToFile(attachment.url, target);
|
await downloadToFile(attachment.url, target, {
|
||||||
|
timeoutMs: DISCORD_ATTACHMENT_DOWNLOAD_TIMEOUT_MS,
|
||||||
|
});
|
||||||
entry.localPath = target;
|
entry.localPath = target;
|
||||||
log.info(`Attachment saved to ${target}`);
|
log.info(`Attachment saved to ${target}`);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@@ -512,7 +512,9 @@ async function maybeDownloadSlackFile(
|
|||||||
}
|
}
|
||||||
const target = buildAttachmentPath(attachmentsDir, 'slack', channelId, name);
|
const target = buildAttachmentPath(attachmentsDir, 'slack', channelId, name);
|
||||||
try {
|
try {
|
||||||
await downloadToFile(url, target, { Authorization: `Bearer ${token}` });
|
await downloadToFile(url, target, {
|
||||||
|
headers: { Authorization: `Bearer ${token}` },
|
||||||
|
});
|
||||||
attachment.localPath = target;
|
attachment.localPath = target;
|
||||||
log.info(`Attachment saved to ${target}`);
|
log.info(`Attachment saved to ${target}`);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
Reference in New Issue
Block a user