From ee28095ebc6ee20aaa7cfc495cd8aeeee0f75ce9 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Thu, 5 Feb 2026 17:55:00 -0800 Subject: [PATCH] feat: add prompt based hooks (#795) Co-authored-by: Letta --- src/cli/App.tsx | 12 +- src/cli/components/HooksManager.tsx | 42 ++- src/cli/helpers/contextTracker.ts | 2 +- src/hooks/executor.ts | 63 +++- src/hooks/index.ts | 2 + src/hooks/loader.ts | 33 +- src/hooks/prompt-executor.ts | 260 +++++++++++++++ src/hooks/types.ts | 86 ++++- src/tests/hooks/loader.test.ts | 66 +++- src/tests/hooks/prompt-executor.test.ts | 425 ++++++++++++++++++++++++ src/tests/settings-manager.test.ts | 18 +- 11 files changed, 967 insertions(+), 42 deletions(-) create mode 100644 src/hooks/prompt-executor.ts create mode 100644 src/tests/hooks/prompt-executor.test.ts diff --git a/src/cli/App.tsx b/src/cli/App.tsx index 47748ba..ba40d14 100644 --- a/src/cli/App.tsx +++ b/src/cli/App.tsx @@ -3336,7 +3336,7 @@ export default function App({ lastDequeuedMessageRef.current = null; // Clear - message was processed successfully lastSentInputRef.current = null; // Clear - no recovery needed - // Get last assistant message and reasoning for Stop hook + // Get last assistant message, user message, and reasoning for Stop hook const lastAssistant = Array.from( buffersRef.current.byId.values(), ).findLast((item) => item.kind === "assistant" && "text" in item); @@ -3344,6 +3344,11 @@ export default function App({ lastAssistant && "text" in lastAssistant ? lastAssistant.text : undefined; + const firstUser = Array.from(buffersRef.current.byId.values()).find( + (item) => item.kind === "user" && "text" in item, + ); + const userMessage = + firstUser && "text" in firstUser ? firstUser.text : undefined; const precedingReasoning = buffersRef.current.lastReasoning; buffersRef.current.lastReasoning = undefined; // Clear after use @@ -3357,6 +3362,7 @@ export default function App({ undefined, // workingDirectory (uses default) precedingReasoning, assistantMessage, + userMessage, ); // If hook blocked (exit 2), inject stderr feedback and continue conversation @@ -3373,9 +3379,7 @@ export default function App({ buffersRef.current.byId.set(statusId, { kind: "status", id: statusId, - lines: [ - "Stop hook encountered blocking error, continuing loop with stderr feedback.", - ], + lines: ["Stop hook blocked, continuing conversation."], }); buffersRef.current.order.push(statusId); refreshDerived(); diff --git a/src/cli/components/HooksManager.tsx b/src/cli/components/HooksManager.tsx index a1d0d5f..5ddae7e 100644 --- a/src/cli/components/HooksManager.tsx +++ b/src/cli/components/HooksManager.tsx @@ -4,8 +4,11 @@ import { Box, useInput } from "ink"; import { memo, useCallback, useEffect, useRef, useState } from "react"; import { + type HookCommand, type HookEvent, type HookMatcher, + isCommandHook, + isPromptHook, isToolEvent, type SimpleHookEvent, type SimpleHookMatcher, @@ -39,6 +42,21 @@ const BOX_BOTTOM_RIGHT = "╯"; const BOX_HORIZONTAL = "─"; const BOX_VERTICAL = "│"; +/** + * Get a display label for a hook (command or prompt). + * For prompt hooks, returns just the prompt text (without prefix). + */ +function getHookDisplayLabel(hook: HookCommand | undefined): string { + if (!hook) return ""; + if (isCommandHook(hook)) { + return hook.command; + } + if (isPromptHook(hook)) { + return `${hook.prompt.slice(0, 40)}${hook.prompt.length > 40 ? "..." : ""}`; + } + return ""; +} + interface HooksManagerProps { onClose: () => void; agentId?: string; @@ -533,10 +551,12 @@ export const HooksManager = memo(function HooksManager({ const matcherPattern = isToolMatcher ? (hook as HookMatcherWithSource).matcher || "*" : null; - // Both types have hooks array - const command = "hooks" in hook ? hook.hooks[0]?.command || "" : ""; + // Both types have hooks array - get display label for first hook + const firstHook = "hooks" in hook ? hook.hooks[0] : undefined; + const command = getHookDisplayLabel(firstHook); const truncatedCommand = command.length > 50 ? `${command.slice(0, 47)}...` : command; + const isPrompt = firstHook ? isPromptHook(firstHook) : false; return ( @@ -549,6 +569,7 @@ export const HooksManager = memo(function HooksManager({ ) : ( )} + {isPrompt && } {truncatedCommand} ); @@ -691,8 +712,10 @@ export const HooksManager = memo(function HooksManager({ const matcherPattern = isToolMatcher ? (hook as HookMatcherWithSource).matcher || "*" : null; - // Both types have hooks array - const command = hook && "hooks" in hook ? hook.hooks[0]?.command : ""; + // Both types have hooks array - get display label for first hook + const firstHook = hook && "hooks" in hook ? hook.hooks[0] : undefined; + const command = getHookDisplayLabel(firstHook); + const isPrompt = firstHook ? isPromptHook(firstHook) : false; return ( @@ -702,7 +725,16 @@ export const HooksManager = memo(function HooksManager({ {matcherPattern !== null && Matcher: {matcherPattern}} - Command: {command} + + {isPrompt ? ( + <> + Hook: + {command} + + ) : ( + <>Command: {command} + )} + Source: {hook ? getSourceLabel(hook.source) : ""} diff --git a/src/cli/helpers/contextTracker.ts b/src/cli/helpers/contextTracker.ts index 560b369..dd83f36 100644 --- a/src/cli/helpers/contextTracker.ts +++ b/src/cli/helpers/contextTracker.ts @@ -22,7 +22,7 @@ export function createContextTracker(): ContextTracker { return { lastContextTokens: 0, contextTokensHistory: [], - currentTurnId: 0, + currentTurnId: 0, // simple in-memory counter for now pendingCompaction: false, }; } diff --git a/src/hooks/executor.ts b/src/hooks/executor.ts index 4997169..cb3dfbe 100644 --- a/src/hooks/executor.ts +++ b/src/hooks/executor.ts @@ -4,17 +4,35 @@ import { type ChildProcess, spawn } from "node:child_process"; import { buildShellLaunchers } from "../tools/impl/shellLaunchers"; +import { executePromptHook } from "./prompt-executor"; import { + type CommandHookConfig, type HookCommand, type HookExecutionResult, HookExitCode, type HookInput, type HookResult, + isCommandHook, + isPromptHook, } from "./types"; /** Default timeout for hook execution (60 seconds) */ const DEFAULT_TIMEOUT_MS = 60000; +/** + * Get a display identifier for a hook (for logging and feedback) + */ +function getHookIdentifier(hook: HookCommand): string { + if (isCommandHook(hook)) { + return hook.command; + } + if (isPromptHook(hook)) { + // Use first 50 chars of prompt as identifier + return `prompt:${hook.prompt.slice(0, 50)}${hook.prompt.length > 50 ? "..." : ""}`; + } + return "unknown"; +} + /** * Try to spawn a hook command with a specific launcher * Returns the child process or throws an error @@ -50,13 +68,45 @@ function trySpawnWithLauncher( } /** - * Execute a single hook command with JSON input via stdin - * Uses cross-platform shell launchers with fallback support + * Execute a single hook with JSON input + * Dispatches to appropriate executor based on hook type: + * - "command": executes shell command with JSON via stdin + * - "prompt": sends to LLM for evaluation */ export async function executeHookCommand( hook: HookCommand, input: HookInput, workingDirectory: string = process.cwd(), +): Promise { + // Dispatch based on hook type + if (isPromptHook(hook)) { + return executePromptHook(hook, input, workingDirectory); + } + + // Default to command hook execution + if (isCommandHook(hook)) { + return executeCommandHook(hook, input, workingDirectory); + } + + // Unknown hook type + return { + exitCode: HookExitCode.ERROR, + stdout: "", + stderr: "", + timedOut: false, + durationMs: 0, + error: `Unknown hook type: ${(hook as HookCommand).type}`, + }; +} + +/** + * Execute a command hook with JSON input via stdin + * Uses cross-platform shell launchers with fallback support + */ +export async function executeCommandHook( + hook: CommandHookConfig, + input: HookInput, + workingDirectory: string = process.cwd(), ): Promise { const startTime = Date.now(); const timeout = hook.timeout ?? DEFAULT_TIMEOUT_MS; @@ -307,11 +357,10 @@ export async function executeHooks( } // Collect feedback from stderr when hook blocks - // Format: [command]: {stderr} per spec if (result.exitCode === HookExitCode.BLOCK) { blocked = true; if (result.stderr) { - feedback.push(`[${hook.command}]: ${result.stderr}`); + feedback.push(`[${getHookIdentifier(hook)}]: ${result.stderr}`); } // Stop processing more hooks after a block break; @@ -358,7 +407,7 @@ export async function executeHooksParallel( const hook = hooks[i]; if (!result || !hook) continue; - // For exit 0, try to parse JSON for additionalContext (matching Claude Code behavior) + // For exit 0, try to parse JSON for additionalContext if (result.exitCode === HookExitCode.ALLOW && result.stdout?.trim()) { try { const json = JSON.parse(result.stdout.trim()); @@ -373,11 +422,11 @@ export async function executeHooksParallel( } } - // Format: [command]: {stderr} per spec + // Collect feedback from stderr when hook blocks if (result.exitCode === HookExitCode.BLOCK) { blocked = true; if (result.stderr) { - feedback.push(`[${hook.command}]: ${result.stderr}`); + feedback.push(`[${getHookIdentifier(hook)}]: ${result.stderr}`); } } if (result.exitCode === HookExitCode.ERROR) { diff --git a/src/hooks/index.ts b/src/hooks/index.ts index 4de286b..af1c582 100644 --- a/src/hooks/index.ts +++ b/src/hooks/index.ts @@ -266,6 +266,7 @@ export async function runStopHooks( workingDirectory: string = process.cwd(), precedingReasoning?: string, assistantMessage?: string, + userMessage?: string, ): Promise { const hooks = await getHooksForEvent("Stop", undefined, workingDirectory); if (hooks.length === 0) { @@ -280,6 +281,7 @@ export async function runStopHooks( tool_call_count: toolCallCount, preceding_reasoning: precedingReasoning, assistant_message: assistantMessage, + user_message: userMessage, }; // Run sequentially - Stop can block diff --git a/src/hooks/loader.ts b/src/hooks/loader.ts index cbbd58b..f22f056 100644 --- a/src/hooks/loader.ts +++ b/src/hooks/loader.ts @@ -8,9 +8,11 @@ import { type HookEvent, type HookMatcher, type HooksConfig, + isPromptHook, isToolEvent, type SimpleHookEvent, type SimpleHookMatcher, + supportsPromptHooks, type ToolHookEvent, } from "./types"; @@ -177,6 +179,33 @@ export function matchesTool(pattern: string, toolName: string): boolean { } } +/** + * Filter hooks, removing prompt hooks from unsupported events with a warning + */ +function filterHooksForEvent( + hooks: HookCommand[], + event: HookEvent, +): HookCommand[] { + const filtered: HookCommand[] = []; + const promptHooksSupported = supportsPromptHooks(event); + + for (const hook of hooks) { + if (isPromptHook(hook)) { + if (!promptHooksSupported) { + // Warn about unsupported prompt hook + console.warn( + `\x1b[33m[hooks] Warning: Prompt hooks are not supported for the ${event} event. ` + + `Ignoring prompt hook.\x1b[0m`, + ); + continue; + } + } + filtered.push(hook); + } + + return filtered; +} + /** * Get all hooks that match a specific event and tool name */ @@ -200,7 +229,7 @@ export function getMatchingHooks( hooks.push(...matcher.hooks); } } - return hooks; + return filterHooksForEvent(hooks, event); } else { // Simple events use SimpleHookMatcher[] - extract hooks from each matcher const matchers = config[event as SimpleHookEvent] as @@ -214,7 +243,7 @@ export function getMatchingHooks( for (const matcher of matchers) { hooks.push(...matcher.hooks); } - return hooks; + return filterHooksForEvent(hooks, event); } } diff --git a/src/hooks/prompt-executor.ts b/src/hooks/prompt-executor.ts new file mode 100644 index 0000000..3492c8b --- /dev/null +++ b/src/hooks/prompt-executor.ts @@ -0,0 +1,260 @@ +import { getClient } from "../agent/client"; +import { getCurrentAgentId } from "../agent/context"; +import { + HookExitCode, + type HookInput, + type HookResult, + PROMPT_ARGUMENTS_PLACEHOLDER, + type PromptHookConfig, + type PromptHookResponse, +} from "./types"; + +/** Default timeout for prompt hook execution (30 seconds) */ +const DEFAULT_PROMPT_TIMEOUT_MS = 30000; + +/** + * System prompt for the LLM to evaluate hooks. + * Instructs the model to return a JSON decision per Claude Code spec. + */ +const PROMPT_HOOK_SYSTEM = `You are a hook evaluator for a coding assistant. Your job is to evaluate whether an action should be allowed or blocked based on the provided context and criteria. + +You will receive: +1. Hook input JSON containing context about the action (event type, tool info, etc.) +2. A user-defined prompt with evaluation criteria + +You must respond with ONLY a valid JSON object (no markdown, no explanation) with the following fields: +- "ok": true to allow the action, false to prevent it +- "reason": Required when ok is false. Explanation for your decision. + +Example responses: +- To allow: {"ok": true} +- To block: {"ok": false, "reason": "This action violates the security policy"} + +Respond with JSON only. No markdown code blocks. No explanation outside the JSON.`; + +/** + * Build the prompt to send to the LLM, replacing $ARGUMENTS with hook input. + * If $ARGUMENTS is not present in the prompt, append the input JSON. + */ +function buildPrompt(hookPrompt: string, input: HookInput): string { + const inputJson = JSON.stringify(input, null, 2); + + // If $ARGUMENTS placeholder exists, replace all occurrences + if (hookPrompt.includes(PROMPT_ARGUMENTS_PLACEHOLDER)) { + return hookPrompt.replaceAll(PROMPT_ARGUMENTS_PLACEHOLDER, inputJson); + } + + // Otherwise, append input JSON to the prompt + return `${hookPrompt}\n\nHook input:\n${inputJson}`; +} + +/** + * Parse the LLM response as JSON, handling potential formatting issues + */ +function parsePromptResponse(response: string): PromptHookResponse { + // Try to extract JSON from the response + let jsonStr = response.trim(); + + // Handle markdown code blocks + const jsonMatch = jsonStr.match(/```(?:json)?\s*([\s\S]*?)\s*```/); + if (jsonMatch) { + jsonStr = jsonMatch[1] || jsonStr; + } + + // Try to find JSON object in the response (non-greedy to avoid spanning multiple objects) + const objectMatch = jsonStr.match(/\{[\s\S]*?\}/); + if (objectMatch) { + jsonStr = objectMatch[0]; + } + + try { + const parsed = JSON.parse(jsonStr); + + // Validate the response structure - ok must be a boolean + if (typeof parsed?.ok !== "boolean") { + throw new Error( + `Invalid prompt hook response: "ok" must be a boolean, got ${typeof parsed?.ok}`, + ); + } + + return parsed as PromptHookResponse; + } catch (e) { + // Re-throw validation errors as-is + if (e instanceof Error && e.message.startsWith("Invalid prompt hook")) { + throw e; + } + // If parsing fails, treat as error + throw new Error(`Failed to parse LLM response as JSON: ${response}`); + } +} + +/** + * Convert PromptHookResponse to HookResult + */ +function responseToHookResult( + response: PromptHookResponse, + durationMs: number, +): HookResult { + // ok: true allows the action, ok: false (or missing) blocks it + const shouldBlock = response.ok !== true; + + return { + exitCode: shouldBlock ? HookExitCode.BLOCK : HookExitCode.ALLOW, + stdout: JSON.stringify(response), + stderr: shouldBlock ? response.reason || "" : "", + timedOut: false, + durationMs, + }; +} + +/** + * Extract agent_id from hook input, falling back to the global agent context. + */ +function getAgentId(input: HookInput): string | undefined { + // 1. Check hook input directly (most hook event types include agent_id) + if ("agent_id" in input && input.agent_id) { + return input.agent_id; + } + // 2. Fall back to the global agent context (set during session) + try { + return getCurrentAgentId(); + } catch { + // Context not available + } + // 3. Last resort: env var (set by shell env for subprocesses) + return process.env.LETTA_AGENT_ID; +} + +/** + * JSON schema for structured prompt hook responses. + * Forces the LLM to return {ok: boolean, reason?: string} via tool calling. + */ +const PROMPT_HOOK_RESPONSE_SCHEMA = { + properties: { + ok: { + type: "boolean", + description: "true to allow the action, false to block it", + }, + reason: { + type: "string", + description: "Explanation for the decision. Required when ok is false.", + }, + }, + required: ["ok"], +}; + +/** Response shape from POST /v1/agents/{agent_id}/generate */ +interface GenerateResponse { + content: string; + model: string; + usage: { + completion_tokens: number; + prompt_tokens: number; + total_tokens: number; + }; +} + +/** + * Execute a prompt-based hook by sending the hook input to an LLM + * via the POST /v1/agents/{agent_id}/generate endpoint. + */ +export async function executePromptHook( + hook: PromptHookConfig, + input: HookInput, + _workingDirectory: string = process.cwd(), +): Promise { + const startTime = Date.now(); + + try { + const agentId = getAgentId(input); + if (!agentId) { + throw new Error( + "Prompt hooks require an agent_id. Ensure the hook event provides an agent_id " + + "or set the LETTA_AGENT_ID environment variable.", + ); + } + + // Build the user prompt with $ARGUMENTS replaced + const userPrompt = buildPrompt(hook.prompt, input); + const timeout = hook.timeout ?? DEFAULT_PROMPT_TIMEOUT_MS; + + // Call the generate endpoint (uses agent's model unless hook overrides) + const llmResponse = await callGenerateEndpoint( + agentId, + PROMPT_HOOK_SYSTEM, + userPrompt, + hook.model, + timeout, + ); + + // Parse the response + const parsedResponse = parsePromptResponse(llmResponse); + const durationMs = Date.now() - startTime; + + // Log hook completion (matching command hook format from executor.ts) + const shouldBlock = parsedResponse.ok !== true; + const exitCode = shouldBlock ? 2 : 0; + const exitColor = shouldBlock ? "\x1b[31m" : "\x1b[32m"; + const exitLabel = `${exitColor}exit ${exitCode}\x1b[0m`; + const promptLabel = `\x1b[38;2;140;140;249m✦\x1b[90m ${hook.prompt.slice(0, 50)}${hook.prompt.length > 50 ? "..." : ""}`; + console.log(`\x1b[90m[hook:${input.event_type}] ${promptLabel}\x1b[0m`); + console.log(`\x1b[90m \u23BF ${exitLabel} (${durationMs}ms)\x1b[0m`); + // Show the JSON response as stdout + const responseJson = JSON.stringify(parsedResponse); + console.log(`\x1b[90m \u23BF (stdout)\x1b[0m`); + console.log(`\x1b[90m ${responseJson}\x1b[0m`); + + return responseToHookResult(parsedResponse, durationMs); + } catch (error) { + const durationMs = Date.now() - startTime; + const errorMessage = error instanceof Error ? error.message : String(error); + const timedOut = errorMessage.includes("timed out"); + + const promptLabel = `\x1b[38;2;140;140;249m✦\x1b[90m ${hook.prompt.slice(0, 50)}${hook.prompt.length > 50 ? "..." : ""}`; + console.log(`\x1b[90m[hook:${input.event_type}] ${promptLabel}\x1b[0m`); + console.log( + `\x1b[90m \u23BF \x1b[33mexit 1\x1b[0m (${durationMs}ms)\x1b[0m`, + ); + console.log(`\x1b[90m \u23BF (stderr)\x1b[0m`); + console.log(`\x1b[90m ${errorMessage}\x1b[0m`); + + return { + exitCode: HookExitCode.ERROR, + stdout: "", + stderr: errorMessage, + timedOut, + durationMs, + error: errorMessage, + }; + } +} + +/** + * Call the POST /v1/agents/{agent_id}/generate endpoint for hook evaluation. + * Uses the Letta SDK client's raw post() method since the SDK doesn't have + * a typed generate() method yet. + */ +async function callGenerateEndpoint( + agentId: string, + systemPrompt: string, + userPrompt: string, + overrideModel: string | undefined, + timeout: number, +): Promise { + const client = await getClient(); + + const response = await client.post( + `/v1/agents/${agentId}/generate`, + { + body: { + prompt: userPrompt, + system_prompt: systemPrompt, + ...(overrideModel && { override_model: overrideModel }), + response_schema: PROMPT_HOOK_RESPONSE_SCHEMA, + }, + timeout, + }, + ); + + return response.content; +} diff --git a/src/hooks/types.ts b/src/hooks/types.ts index 899f54f..e0a8489 100644 --- a/src/hooks/types.ts +++ b/src/hooks/types.ts @@ -29,17 +29,82 @@ export type SimpleHookEvent = export type HookEvent = ToolHookEvent | SimpleHookEvent; /** - * Individual hook command configuration + * Command hook configuration - executes a shell command */ -export interface HookCommand { - /** Type of hook - currently only "command" is supported */ +export interface CommandHookConfig { + /** Type of hook */ type: "command"; /** Shell command to execute */ command: string; - /** Optional timeout in milliseconds (default: 60000) */ + /** Optional timeout in milliseconds (default: 60000 for command hooks) */ timeout?: number; } +/** + * Prompt hook configuration - sends hook input to an LLM for evaluation. + * Supported events: PreToolUse, PostToolUse, PostToolUseFailure, + * PermissionRequest, UserPromptSubmit, Stop, and SubagentStop. + */ +export interface PromptHookConfig { + /** Type of hook */ + type: "prompt"; + /** + * Prompt text to send to the model. + * Use $ARGUMENTS as a placeholder for the hook input JSON. + */ + prompt: string; + /** Optional model to use for evaluation */ + model?: string; + /** Optional timeout in milliseconds (default: 30000 for prompt hooks) */ + timeout?: number; +} + +/** + * Placeholder for $ARGUMENTS in prompt hooks + */ +export const PROMPT_ARGUMENTS_PLACEHOLDER = "$ARGUMENTS"; + +/** + * Events that support prompt-based hooks: + * PreToolUse, PostToolUse, PostToolUseFailure, PermissionRequest, + * UserPromptSubmit, Stop, SubagentStop + */ +export const PROMPT_HOOK_SUPPORTED_EVENTS: Set = new Set([ + "PreToolUse", + "PostToolUse", + "PostToolUseFailure", + "PermissionRequest", + "UserPromptSubmit", + "Stop", + "SubagentStop", +]); + +/** + * Type guard to check if an event supports prompt hooks + */ +export function supportsPromptHooks(event: HookEvent): boolean { + return PROMPT_HOOK_SUPPORTED_EVENTS.has(event); +} + +/** + * Individual hook configuration - can be command or prompt type + */ +export type HookCommand = CommandHookConfig | PromptHookConfig; + +/** + * Type guard to check if a hook is a command hook + */ +export function isCommandHook(hook: HookCommand): hook is CommandHookConfig { + return hook.type === "command"; +} + +/** + * Type guard to check if a hook is a prompt hook + */ +export function isPromptHook(hook: HookCommand): hook is PromptHookConfig { + return hook.type === "prompt"; +} + /** * Hook matcher configuration for tool events - matches hooks to specific tools */ @@ -125,6 +190,17 @@ export interface HookResult { error?: string; } +/** + * Expected JSON response structure from prompt hooks. + * The LLM must respond with this schema per Claude Code spec. + */ +export interface PromptHookResponse { + /** true allows the action, false prevents it */ + ok: boolean; + /** Required when ok is false. Explanation shown to Claude. */ + reason?: string; +} + /** * Aggregated result from running all matched hooks */ @@ -281,6 +357,8 @@ export interface StopHookInput extends HookInputBase { preceding_reasoning?: string; /** The assistant's final message content */ assistant_message?: string; + /** The user's original prompt that initiated this turn */ + user_message?: string; } /** diff --git a/src/tests/hooks/loader.test.ts b/src/tests/hooks/loader.test.ts index 589c858..06b59fc 100644 --- a/src/tests/hooks/loader.test.ts +++ b/src/tests/hooks/loader.test.ts @@ -12,6 +12,8 @@ import { mergeHooksConfigs, } from "../../hooks/loader"; import { + type CommandHookConfig, + type HookCommand, type HookEvent, type HooksConfig, isToolEvent, @@ -20,6 +22,16 @@ import { } from "../../hooks/types"; import { settingsManager } from "../../settings-manager"; +// Type-safe helper to extract command from a hook (tests only use command hooks) +function asCommand( + hook: HookCommand | undefined, +): CommandHookConfig | undefined { + if (hook && hook.type === "command") { + return hook as CommandHookConfig; + } + return undefined; +} + describe("Hooks Loader", () => { let tempDir: string; let fakeHome: string; @@ -215,11 +227,11 @@ describe("Hooks Loader", () => { const bashHooks = getMatchingHooks(config, "PreToolUse", "Bash"); expect(bashHooks).toHaveLength(1); - expect(bashHooks[0]?.command).toBe("bash hook"); + expect(asCommand(bashHooks[0])?.command).toBe("bash hook"); const editHooks = getMatchingHooks(config, "PreToolUse", "Edit"); expect(editHooks).toHaveLength(1); - expect(editHooks[0]?.command).toBe("edit hook"); + expect(asCommand(editHooks[0])?.command).toBe("edit hook"); }); test("returns wildcard hooks for any tool", () => { @@ -234,7 +246,7 @@ describe("Hooks Loader", () => { const hooks = getMatchingHooks(config, "PreToolUse", "AnyTool"); expect(hooks).toHaveLength(1); - expect(hooks[0]?.command).toBe("all tools hook"); + expect(asCommand(hooks[0])?.command).toBe("all tools hook"); }); test("returns multiple matching hooks", () => { @@ -315,9 +327,9 @@ describe("Hooks Loader", () => { const hooks = getMatchingHooks(config, "PreToolUse", "Bash"); expect(hooks).toHaveLength(3); - expect(hooks[0]?.command).toBe("multi tool"); - expect(hooks[1]?.command).toBe("bash specific"); - expect(hooks[2]?.command).toBe("wildcard"); + expect(asCommand(hooks[0])?.command).toBe("multi tool"); + expect(asCommand(hooks[1])?.command).toBe("bash specific"); + expect(asCommand(hooks[2])?.command).toBe("wildcard"); }); }); @@ -496,7 +508,9 @@ describe("Hooks Loader", () => { const hooks = await loadProjectLocalHooks(tempDir); expect(hooks.PreToolUse).toHaveLength(1); - expect(hooks.PreToolUse?.[0]?.hooks[0]?.command).toBe("echo local"); + expect(asCommand(hooks.PreToolUse?.[0]?.hooks[0])?.command).toBe( + "echo local", + ); }); }); @@ -517,8 +531,12 @@ describe("Hooks Loader", () => { const merged = mergeHooksConfigs(global, project, projectLocal); expect(merged.PreToolUse).toHaveLength(2); - expect(merged.PreToolUse?.[0]?.hooks[0]?.command).toBe("local"); // Local first - expect(merged.PreToolUse?.[1]?.hooks[0]?.command).toBe("project"); // Project second + expect(asCommand(merged.PreToolUse?.[0]?.hooks[0])?.command).toBe( + "local", + ); // Local first + expect(asCommand(merged.PreToolUse?.[1]?.hooks[0])?.command).toBe( + "project", + ); // Project second }); test("project-local hooks run before global hooks", () => { @@ -537,8 +555,12 @@ describe("Hooks Loader", () => { const merged = mergeHooksConfigs(global, project, projectLocal); expect(merged.PreToolUse).toHaveLength(2); - expect(merged.PreToolUse?.[0]?.hooks[0]?.command).toBe("local"); // Local first - expect(merged.PreToolUse?.[1]?.hooks[0]?.command).toBe("global"); // Global last + expect(asCommand(merged.PreToolUse?.[0]?.hooks[0])?.command).toBe( + "local", + ); // Local first + expect(asCommand(merged.PreToolUse?.[1]?.hooks[0])?.command).toBe( + "global", + ); // Global last }); test("all three levels merge correctly", () => { @@ -577,9 +599,15 @@ describe("Hooks Loader", () => { // PreToolUse: local -> project -> global expect(merged.PreToolUse).toHaveLength(3); - expect(merged.PreToolUse?.[0]?.hooks[0]?.command).toBe("local"); - expect(merged.PreToolUse?.[1]?.hooks[0]?.command).toBe("project"); - expect(merged.PreToolUse?.[2]?.hooks[0]?.command).toBe("global"); + expect(asCommand(merged.PreToolUse?.[0]?.hooks[0])?.command).toBe( + "local", + ); + expect(asCommand(merged.PreToolUse?.[1]?.hooks[0])?.command).toBe( + "project", + ); + expect(asCommand(merged.PreToolUse?.[2]?.hooks[0])?.command).toBe( + "global", + ); // Others only have one source expect(merged.PostToolUse).toHaveLength(1); @@ -624,8 +652,10 @@ describe("Hooks Loader", () => { // Local should come before project expect(hooks.PreToolUse).toHaveLength(2); - expect(hooks.PreToolUse?.[0]?.hooks[0]?.command).toBe("local"); - expect(hooks.PreToolUse?.[1]?.hooks[0]?.command).toBe("project"); + expect(asCommand(hooks.PreToolUse?.[0]?.hooks[0])?.command).toBe("local"); + expect(asCommand(hooks.PreToolUse?.[1]?.hooks[0])?.command).toBe( + "project", + ); }); test("handles missing local settings gracefully", async () => { @@ -650,7 +680,9 @@ describe("Hooks Loader", () => { const hooks = await loadHooks(tempDir); expect(hooks.PreToolUse).toHaveLength(1); - expect(hooks.PreToolUse?.[0]?.hooks[0]?.command).toBe("project"); + expect(asCommand(hooks.PreToolUse?.[0]?.hooks[0])?.command).toBe( + "project", + ); }); }); }); diff --git a/src/tests/hooks/prompt-executor.test.ts b/src/tests/hooks/prompt-executor.test.ts new file mode 100644 index 0000000..f7f0dbe --- /dev/null +++ b/src/tests/hooks/prompt-executor.test.ts @@ -0,0 +1,425 @@ +import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { executePromptHook } from "../../hooks/prompt-executor"; +import { + HookExitCode, + type PreToolUseHookInput, + type StopHookInput, +} from "../../hooks/types"; + +interface GenerateOpts { + body: { + prompt: string; + system_prompt: string; + override_model?: string; + response_schema?: Record; + }; + timeout?: number; +} + +// Mock getClient to avoid real API calls +const mockPost = mock( + (_path: string, _opts: GenerateOpts) => + Promise.resolve({ + content: '{"ok": true}', + model: "test-model", + usage: { + completion_tokens: 10, + prompt_tokens: 50, + total_tokens: 60, + }, + }) as Promise>, +); +const mockGetClient = mock(() => Promise.resolve({ post: mockPost })); + +// Mock getCurrentAgentId +const mockGetCurrentAgentId = mock(() => "agent-test-123"); + +mock.module("../../agent/client", () => ({ + getClient: mockGetClient, +})); + +mock.module("../../agent/context", () => ({ + getCurrentAgentId: mockGetCurrentAgentId, +})); + +/** Helper to get the first call's [path, opts] from mockPost */ +function firstPostCall(): [string, GenerateOpts] { + const calls = mockPost.mock.calls; + const call = calls[0]; + if (!call) throw new Error("mockPost was not called"); + return call; +} + +describe("Prompt Hook Executor", () => { + beforeEach(() => { + mockPost.mockClear(); + mockGetClient.mockClear(); + mockGetCurrentAgentId.mockClear(); + + // Default: allow + mockPost.mockResolvedValue({ + content: '{"ok": true}', + model: "anthropic/claude-3-5-haiku-20241022", + usage: { + completion_tokens: 10, + prompt_tokens: 50, + total_tokens: 60, + }, + }); + }); + + afterEach(() => { + // Clean up env vars + delete process.env.LETTA_AGENT_ID; + }); + + describe("executePromptHook", () => { + test("calls generate endpoint and returns ALLOW when ok is true", async () => { + const hook = { + type: "prompt" as const, + prompt: "Check if this tool call is safe", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.ALLOW); + expect(result.timedOut).toBe(false); + expect(mockGetClient).toHaveBeenCalledTimes(1); + expect(mockPost).toHaveBeenCalledTimes(1); + + // Verify the correct path and body were sent + const [path, opts] = firstPostCall(); + expect(path).toBe("/v1/agents/agent-abc-123/generate"); + expect(opts.body.prompt).toContain("Check if this tool call is safe"); + expect(opts.body.system_prompt).toBeTruthy(); + expect(opts.body.override_model).toBeUndefined(); + expect(opts.body.response_schema).toBeDefined(); + const schema = opts.body.response_schema as { + properties: { ok: { type: string } }; + }; + expect(schema.properties.ok.type).toBe("boolean"); + }); + + test("returns BLOCK when ok is false", async () => { + mockPost.mockResolvedValue({ + content: '{"ok": false, "reason": "Dangerous command detected"}', + model: "anthropic/claude-3-5-haiku-20241022", + usage: { + completion_tokens: 15, + prompt_tokens: 50, + total_tokens: 65, + }, + }); + + const hook = { + type: "prompt" as const, + prompt: "Block dangerous commands", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "rm -rf /" }, + agent_id: "agent-abc-123", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.BLOCK); + expect(result.stderr).toBe("Dangerous command detected"); + }); + + test("uses custom model when specified in hook config", async () => { + const hook = { + type: "prompt" as const, + prompt: "Evaluate this action", + model: "openai/gpt-4o", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Edit", + tool_input: { file_path: "/etc/passwd" }, + agent_id: "agent-abc-123", + }; + + await executePromptHook(hook, input); + + const [, opts] = firstPostCall(); + expect(opts.body.override_model).toBe("openai/gpt-4o"); + }); + + test("uses custom timeout when specified", async () => { + const hook = { + type: "prompt" as const, + prompt: "Evaluate this action", + timeout: 5000, + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "echo hi" }, + agent_id: "agent-abc-123", + }; + + await executePromptHook(hook, input); + + const [, opts] = firstPostCall(); + expect(opts.timeout).toBe(5000); + }); + + test("replaces $ARGUMENTS placeholder in prompt", async () => { + const hook = { + type: "prompt" as const, + prompt: 'Check if tool "$ARGUMENTS" is safe to run', + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + await executePromptHook(hook, input); + + const [, opts] = firstPostCall(); + // $ARGUMENTS should have been replaced with JSON + expect(opts.body.prompt).not.toContain("$ARGUMENTS"); + expect(opts.body.prompt).toContain('"event_type": "PreToolUse"'); + expect(opts.body.prompt).toContain('"tool_name": "Bash"'); + }); + + test("appends hook input when $ARGUMENTS is not in prompt", async () => { + const hook = { + type: "prompt" as const, + prompt: "Is this tool call safe?", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + await executePromptHook(hook, input); + + const [, opts] = firstPostCall(); + expect(opts.body.prompt).toContain("Is this tool call safe?"); + expect(opts.body.prompt).toContain("Hook input:"); + expect(opts.body.prompt).toContain('"tool_name": "Bash"'); + }); + + test("falls back to getCurrentAgentId when input has no agent_id", async () => { + mockGetCurrentAgentId.mockReturnValue("agent-from-context"); + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: StopHookInput = { + event_type: "Stop", + working_directory: "/tmp", + stop_reason: "end_turn", + }; + + await executePromptHook(hook, input); + + const [path] = firstPostCall(); + expect(path).toBe("/v1/agents/agent-from-context/generate"); + }); + + test("falls back to LETTA_AGENT_ID env var when context unavailable", async () => { + mockGetCurrentAgentId.mockImplementation(() => { + throw new Error("No agent context set"); + }); + process.env.LETTA_AGENT_ID = "agent-from-env"; + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: StopHookInput = { + event_type: "Stop", + working_directory: "/tmp", + stop_reason: "end_turn", + }; + + await executePromptHook(hook, input); + + const [path] = firstPostCall(); + expect(path).toBe("/v1/agents/agent-from-env/generate"); + }); + + test("returns ERROR when no agent_id available", async () => { + mockGetCurrentAgentId.mockImplementation(() => { + throw new Error("No agent context set"); + }); + delete process.env.LETTA_AGENT_ID; + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: StopHookInput = { + event_type: "Stop", + working_directory: "/tmp", + stop_reason: "end_turn", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.ERROR); + expect(result.error).toContain("agent_id"); + expect(mockPost).not.toHaveBeenCalled(); + }); + + test("returns ERROR when API call fails", async () => { + mockPost.mockRejectedValue(new Error("Network error")); + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.ERROR); + expect(result.error).toContain("Network error"); + }); + + test("returns ERROR when LLM returns unparseable response", async () => { + mockPost.mockResolvedValue({ + content: "This is not valid JSON at all", + model: "anthropic/claude-3-5-haiku-20241022", + usage: { + completion_tokens: 10, + prompt_tokens: 50, + total_tokens: 60, + }, + }); + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.ERROR); + expect(result.error).toContain("Failed to parse"); + }); + + test("handles JSON wrapped in markdown code blocks", async () => { + mockPost.mockResolvedValue({ + content: '```json\n{"ok": true}\n```', + model: "anthropic/claude-3-5-haiku-20241022", + usage: { + completion_tokens: 10, + prompt_tokens: 50, + total_tokens: 60, + }, + }); + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.ALLOW); + }); + + test("returns ERROR when ok is not a boolean", async () => { + mockPost.mockResolvedValue({ + content: '{"ok": "yes", "reason": "looks fine"}', + model: "anthropic/claude-3-5-haiku-20241022", + usage: { + completion_tokens: 10, + prompt_tokens: 50, + total_tokens: 60, + }, + }); + + const hook = { + type: "prompt" as const, + prompt: "Check this", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + const result = await executePromptHook(hook, input); + + expect(result.exitCode).toBe(HookExitCode.ERROR); + expect(result.error).toContain('"ok" must be a boolean'); + }); + + test("sends response_schema for structured output", async () => { + const hook = { + type: "prompt" as const, + prompt: "Is this safe?", + }; + const input: PreToolUseHookInput = { + event_type: "PreToolUse", + working_directory: "/tmp", + tool_name: "Bash", + tool_input: { command: "ls" }, + agent_id: "agent-abc-123", + }; + + await executePromptHook(hook, input); + + const [, opts] = firstPostCall(); + expect(opts.body.response_schema).toEqual({ + properties: { + ok: { + type: "boolean", + description: "true to allow the action, false to block it", + }, + reason: { + type: "string", + description: + "Explanation for the decision. Required when ok is false.", + }, + }, + required: ["ok"], + }); + }); + }); +}); diff --git a/src/tests/settings-manager.test.ts b/src/tests/settings-manager.test.ts index fae5fdb..52ef3cf 100644 --- a/src/tests/settings-manager.test.ts +++ b/src/tests/settings-manager.test.ts @@ -2,7 +2,19 @@ import { afterEach, beforeEach, describe, expect, test } from "bun:test"; import { mkdtemp, rm } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; +import type { CommandHookConfig, HookCommand } from "../hooks/types"; import { settingsManager } from "../settings-manager"; + +// Type-safe helper to extract command from a hook (tests only use command hooks) +function asCommand( + hook: HookCommand | undefined, +): CommandHookConfig | undefined { + if (hook && hook.type === "command") { + return hook as CommandHookConfig; + } + return undefined; +} + import { deleteSecureTokens, isKeychainAvailable, @@ -558,7 +570,7 @@ describe("Settings Manager - Hooks", () => { const settings = settingsManager.getSettings(); expect(settings.hooks?.PreToolUse).toHaveLength(1); - expect(settings.hooks?.PreToolUse?.[0]?.hooks[0]?.command).toBe( + expect(asCommand(settings.hooks?.PreToolUse?.[0]?.hooks[0])?.command).toBe( "echo persisted", ); expect(settings.hooks?.SessionStart).toHaveLength(1); @@ -631,7 +643,9 @@ describe("Settings Manager - Hooks", () => { expect(reloaded.hooks?.Stop).toHaveLength(1); // Simple event hooks are in SimpleHookMatcher format with hooks array - expect(reloaded.hooks?.Stop?.[0]?.hooks[0]?.command).toBe("echo stop-hook"); + expect(asCommand(reloaded.hooks?.Stop?.[0]?.hooks[0])?.command).toBe( + "echo stop-hook", + ); }); test("All 11 hook event types can be configured", async () => {