diff --git a/src/channels/discord-adapter.test.ts b/src/channels/discord-adapter.test.ts new file mode 100644 index 0000000..eadc272 --- /dev/null +++ b/src/channels/discord-adapter.test.ts @@ -0,0 +1,262 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; + +const discordMock = vi.hoisted(() => { + type Handler = (...args: unknown[]) => unknown | Promise; + + class MockDiscordClient { + private handlers = new Map(); + user = { id: 'bot-self', tag: 'bot#0001' }; + channels = { fetch: vi.fn() }; + destroy = vi.fn(); + + once(event: string, handler: Handler): this { + return this.on(event, handler); + } + + on(event: string, handler: Handler): this { + const existing = this.handlers.get(event) || []; + existing.push(handler); + this.handlers.set(event, existing); + return this; + } + + async login(): Promise { + await this.emit('clientReady'); + return 'ok'; + } + + async emit(event: string, ...args: unknown[]): Promise { + const handlers = this.handlers.get(event) || []; + for (const handler of handlers) { + await handler(...args); + } + } + } + + let latestClient: MockDiscordClient | null = null; + class Client extends MockDiscordClient { + constructor(_options: unknown) { + super(); + latestClient = this; + } + } + + return { + Client, + getLatestClient: () => latestClient, + }; +}); + +vi.mock('discord.js', () => ({ + Client: discordMock.Client, + GatewayIntentBits: { + Guilds: 1, + GuildMessages: 2, + GuildMessageReactions: 3, + MessageContent: 4, + DirectMessages: 5, + DirectMessageReactions: 6, + }, + Partials: { + Channel: 1, + Message: 2, + Reaction: 3, + User: 4, + }, +})); + +const { DiscordAdapter } = await import('./discord.js'); + +function makeMessage(params: { + content: string; + isThread: boolean; + channelId: string; + parentId?: string; +}): { + id: string; + content: string; + guildId: string; + channel: { + id: string; + parentId?: string; + name: string; + send: ReturnType; + isThread: () => boolean; + isTextBased: () => boolean; + }; + author: { + id: string; + bot: boolean; + username: string; + globalName: string; + send: ReturnType; + }; + member: { displayName: string }; + mentions: { has: () => boolean }; + attachments: { find: () => undefined; values: () => unknown[] }; + createdAt: Date; + reply: ReturnType; + startThread: ReturnType; +} { + return { + id: 'msg-1', + content: params.content, + guildId: 'guild-1', + channel: { + id: params.channelId, + parentId: params.parentId, + name: 'general', + send: vi.fn().mockResolvedValue({ id: 'sent-1' }), + isThread: () => params.isThread, + isTextBased: () => true, + }, + author: { + id: 'user-1', + bot: false, + username: 'alice', + globalName: 'Alice', + send: vi.fn().mockResolvedValue(undefined), + }, + member: { displayName: 'Alice' }, + mentions: { has: () => false }, + attachments: { + find: () => undefined, + values: () => [], + }, + createdAt: new Date(), + reply: vi.fn().mockResolvedValue(undefined), + startThread: vi.fn().mockResolvedValue({ id: 'thread-created', name: 'new thread' }), + }; +} + +describe('DiscordAdapter command gating', () => { + afterEach(async () => { + vi.clearAllMocks(); + }); + + it('blocks top-level slash commands when threadMode is thread-only', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + groups: { + 'channel-1': { mode: 'open', threadMode: 'thread-only' }, + }, + }); + const onCommand = vi.fn().mockResolvedValue('ok'); + adapter.onCommand = onCommand; + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const message = makeMessage({ + content: '/status', + isThread: false, + channelId: 'channel-1', + }); + + await client!.emit('messageCreate', message); + + expect(onCommand).not.toHaveBeenCalled(); + expect(message.channel.send).not.toHaveBeenCalled(); + await adapter.stop(); + }); + + it('allows slash commands inside threads in thread-only mode', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + groups: { + 'channel-1': { mode: 'open', threadMode: 'thread-only' }, + }, + }); + const onCommand = vi.fn().mockResolvedValue('ok'); + adapter.onCommand = onCommand; + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const message = makeMessage({ + content: '/status', + isThread: true, + channelId: 'thread-1', + parentId: 'channel-1', + }); + + await client!.emit('messageCreate', message); + + expect(onCommand).toHaveBeenCalledTimes(1); + expect(onCommand).toHaveBeenCalledWith('status', 'thread-1', undefined); + expect(message.channel.send).toHaveBeenCalledWith('ok'); + await adapter.stop(); + }); + + it('redirects mentioned top-level commands into an auto-created thread', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + groups: { + 'channel-1': { mode: 'open', threadMode: 'thread-only', autoCreateThreadOnMention: true }, + }, + }); + const onCommand = vi.fn().mockResolvedValue('ok'); + adapter.onCommand = onCommand; + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const threadSend = vi.fn().mockResolvedValue({ id: 'thread-msg-1' }); + (client!.channels.fetch as ReturnType).mockResolvedValue({ + id: 'thread-created', + isTextBased: () => true, + send: threadSend, + }); + + const message = makeMessage({ + content: '/status', + isThread: false, + channelId: 'channel-1', + }); + message.mentions = { has: () => true }; + + await client!.emit('messageCreate', message); + + expect(message.startThread).toHaveBeenCalledTimes(1); + expect(onCommand).toHaveBeenCalledTimes(1); + expect(onCommand).toHaveBeenCalledWith('status', 'thread-created', undefined); + expect(threadSend).toHaveBeenCalledWith('ok'); + expect(message.channel.send).not.toHaveBeenCalled(); + await adapter.stop(); + }); + + it('creates one thread when unknown slash commands fall through to agent handling', async () => { + const adapter = new DiscordAdapter({ + token: 'token', + groups: { + 'channel-1': { mode: 'open', threadMode: 'thread-only', autoCreateThreadOnMention: true }, + }, + }); + const onMessage = vi.fn().mockResolvedValue(undefined); + adapter.onMessage = onMessage; + + await adapter.start(); + const client = discordMock.getLatestClient(); + expect(client).toBeTruthy(); + + const message = makeMessage({ + content: '/unknown', + isThread: false, + channelId: 'channel-1', + }); + message.mentions = { has: () => true }; + + await client!.emit('messageCreate', message); + + expect(message.startThread).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledTimes(1); + expect(onMessage).toHaveBeenCalledWith(expect.objectContaining({ + chatId: 'thread-created', + text: '/unknown', + })); + await adapter.stop(); + }); +}); diff --git a/src/channels/discord.ts b/src/channels/discord.ts index ffafd52..5efb520 100644 --- a/src/channels/discord.ts +++ b/src/channels/discord.ts @@ -229,6 +229,7 @@ Ask the bot owner to approve with: serverId: message.guildId, }); const selfUserId = this.client?.user?.id; + const wasMentioned = isGroup && !!this.client?.user && message.mentions.has(this.client.user); if (!shouldProcessDiscordBotMessage({ isFromBot, @@ -312,15 +313,62 @@ Ask the bot owner to approve with: const parts = content.slice(1).split(/\s+/); const command = parts[0]?.toLowerCase(); const cmdArgs = parts.slice(1).join(' ') || undefined; - if (command === 'help' || command === 'start') { - await message.channel.send(HELP_TEXT); - return; - } - if (this.onCommand) { - if (command === 'status' || command === 'reset' || command === 'heartbeat' || command === 'cancel' || command === 'model' || command === 'setconv') { - const result = await this.onCommand(command, message.channel.id, cmdArgs); + const isHelpCommand = command === 'help' || command === 'start'; + const isManagedCommand = + command === 'status' || + command === 'reset' || + command === 'heartbeat' || + command === 'cancel' || + command === 'model' || + command === 'setconv'; + + // Unknown commands (or managed commands without onCommand) fall through to agent processing. + if (isHelpCommand || (isManagedCommand && this.onCommand)) { + let commandChatId = message.channel.id; + let commandSendTarget: { send: (content: string) => Promise } | null = + message.channel.isTextBased() && 'send' in message.channel + ? (message.channel as { send: (content: string) => Promise }) + : null; + + if (isGroup && this.config.groups) { + const threadMode = resolveDiscordThreadMode(this.config.groups, keys); + if (threadMode === 'thread-only' && !isThreadMessage) { + const shouldCreateThread = + wasMentioned && resolveDiscordAutoCreateThreadOnMention(this.config.groups, keys); + if (!shouldCreateThread) { + return; + } + + // Keep command behavior aligned with normal message gating in thread-only mode. + const createdThread = await this.createThreadForMention(message, content); + if (!createdThread) { + return; + } + + if (!this.client) { + return; + } + const threadChannel = await this.client.channels.fetch(createdThread.id); + if (!threadChannel || !threadChannel.isTextBased() || !('send' in threadChannel)) { + return; + } + + commandChatId = createdThread.id; + commandSendTarget = threadChannel as { send: (content: string) => Promise }; + } + } + + if (isHelpCommand) { + if (!commandSendTarget) return; + await commandSendTarget.send(HELP_TEXT); + return; + } + + if (this.onCommand && isManagedCommand) { + const result = await this.onCommand(command, commandChatId, cmdArgs); if (result) { - await message.channel.send(result); + if (!commandSendTarget) return; + await commandSendTarget.send(result); } return; } @@ -330,7 +378,6 @@ Ask the bot owner to approve with: if (this.onMessage) { const groupName = isGroup && 'name' in message.channel ? message.channel.name : undefined; const displayName = message.member?.displayName || message.author.globalName || message.author.username; - const wasMentioned = isGroup && !!this.client?.user && message.mentions.has(this.client.user); let isListeningMode = false; let effectiveChatId = message.channel.id; let effectiveGroupName = groupName;