feat(remote): support per-conversation working directories in listener mode (#1323)

This commit is contained in:
Charles Packer
2026-03-10 13:42:42 -07:00
committed by GitHub
parent e82a2d33f8
commit 4c9f63c4e2
13 changed files with 482 additions and 45 deletions

View File

@@ -10,6 +10,7 @@ import type { ToolReturnMessage } from "@letta-ai/letta-client/resources/tools";
import type { ApprovalRequest } from "../cli/helpers/stream"; import type { ApprovalRequest } from "../cli/helpers/stream";
import { INTERRUPTED_BY_USER } from "../constants"; import { INTERRUPTED_BY_USER } from "../constants";
import { import {
captureToolExecutionContext,
executeTool, executeTool,
type ToolExecutionResult, type ToolExecutionResult,
type ToolReturnContent, type ToolReturnContent,
@@ -135,6 +136,7 @@ const GLOBAL_LOCK_TOOLS = new Set([
export function getResourceKey( export function getResourceKey(
toolName: string, toolName: string,
toolArgs: Record<string, unknown>, toolArgs: Record<string, unknown>,
workingDirectory: string = process.env.USER_CWD || process.cwd(),
): string { ): string {
// Global lock tools serialize with everything // Global lock tools serialize with everything
if (GLOBAL_LOCK_TOOLS.has(toolName)) { if (GLOBAL_LOCK_TOOLS.has(toolName)) {
@@ -146,10 +148,9 @@ export function getResourceKey(
const filePath = toolArgs.file_path; const filePath = toolArgs.file_path;
if (typeof filePath === "string") { if (typeof filePath === "string") {
// Normalize to absolute path for consistent comparison // Normalize to absolute path for consistent comparison
const userCwd = process.env.USER_CWD || process.cwd();
return path.isAbsolute(filePath) return path.isAbsolute(filePath)
? path.normalize(filePath) ? path.normalize(filePath)
: path.resolve(userCwd, filePath); : path.resolve(workingDirectory, filePath);
} }
} }
@@ -360,8 +361,15 @@ export async function executeApprovalBatch(
isStderr?: boolean, isStderr?: boolean,
) => void; ) => void;
toolContextId?: string; toolContextId?: string;
workingDirectory?: string;
}, },
): Promise<ApprovalResult[]> { ): Promise<ApprovalResult[]> {
const toolContextId =
options?.toolContextId ??
(options?.workingDirectory
? captureToolExecutionContext(options.workingDirectory).contextId
: undefined);
// Pre-allocate results array to maintain original order // Pre-allocate results array to maintain original order
const results: (ApprovalResult | null)[] = new Array(decisions.length).fill( const results: (ApprovalResult | null)[] = new Array(decisions.length).fill(
null, null,
@@ -399,7 +407,11 @@ export async function executeApprovalBatch(
} else { } else {
args = decision.approval.toolArgs || {}; args = decision.approval.toolArgs || {};
} }
const resourceKey = getResourceKey(toolName, args); const resourceKey = getResourceKey(
toolName,
args,
options?.workingDirectory,
);
const indices = writeToolsByResource.get(resourceKey) || []; const indices = writeToolsByResource.get(resourceKey) || [];
indices.push(i); indices.push(i);
@@ -411,7 +423,10 @@ export async function executeApprovalBatch(
const execute = async (i: number) => { const execute = async (i: number) => {
const decision = decisions[i]; const decision = decisions[i];
if (decision) { if (decision) {
results[i] = await executeSingleDecision(decision, onChunk, options); results[i] = await executeSingleDecision(decision, onChunk, {
...options,
toolContextId,
});
} }
}; };
@@ -456,6 +471,7 @@ export async function executeAutoAllowedTools(
isStderr?: boolean, isStderr?: boolean,
) => void; ) => void;
toolContextId?: string; toolContextId?: string;
workingDirectory?: string;
}, },
): Promise<AutoAllowedResult[]> { ): Promise<AutoAllowedResult[]> {
const decisions: ApprovalDecision[] = autoAllowed.map((ac) => ({ const decisions: ApprovalDecision[] = autoAllowed.map((ac) => ({

View File

@@ -479,14 +479,11 @@ export async function getResumeData(
// may not support this pattern) // may not support this pattern)
if (includeMessageHistory && isBackfillEnabled()) { if (includeMessageHistory && isBackfillEnabled()) {
try { try {
const messagesPage = await client.conversations.messages.list( const messagesPage = await client.agents.messages.list(agent.id, {
"default", conversation_id: "default",
{ limit: BACKFILL_PAGE_LIMIT,
limit: BACKFILL_PAGE_LIMIT, order: "desc",
order: "desc", });
agent_id: agent.id,
},
);
messages = sortChronological(messagesPage.getPaginatedItems()); messages = sortChronological(messagesPage.getPaginatedItems());
if (process.env.DEBUG) { if (process.env.DEBUG) {

View File

@@ -55,6 +55,7 @@ export type SendMessageStreamOptions = {
background?: boolean; background?: boolean;
agentId?: string; // Required when conversationId is "default" agentId?: string; // Required when conversationId is "default"
approvalNormalization?: ApprovalNormalizationOptions; approvalNormalization?: ApprovalNormalizationOptions;
workingDirectory?: string;
}; };
export function buildConversationMessagesCreateRequestBody( export function buildConversationMessagesCreateRequestBody(
@@ -118,7 +119,9 @@ export async function sendMessageStream(
// Wait for any in-progress toolset switch to complete before reading tools // Wait for any in-progress toolset switch to complete before reading tools
// This prevents sending messages with stale tools during a switch // This prevents sending messages with stale tools during a switch
await waitForToolsetReady(); await waitForToolsetReady();
const { clientTools, contextId } = captureToolExecutionContext(); const { clientTools, contextId } = captureToolExecutionContext(
opts.workingDirectory,
);
const { clientSkills, errors: clientSkillDiscoveryErrors } = const { clientSkills, errors: clientSkillDiscoveryErrors } =
await buildClientSkillsPayload({ await buildClientSkillsPayload({
agentId: opts.agentId, agentId: opts.agentId,

View File

@@ -278,7 +278,8 @@ export async function recompileAgentSystemPrompt(
options: RecompileAgentSystemPromptOptions = {}, options: RecompileAgentSystemPromptOptions = {},
clientOverride?: AgentSystemPromptRecompileClient, clientOverride?: AgentSystemPromptRecompileClient,
): Promise<string> { ): Promise<string> {
const client = clientOverride ?? (await getClient()); const client = (clientOverride ??
(await getClient())) as AgentSystemPromptRecompileClient;
return client.agents.recompile(agentId, { return client.agents.recompile(agentId, {
dry_run: options.dryRun, dry_run: options.dryRun,

View File

@@ -242,14 +242,11 @@ export function ConversationSelector({
let defaultConversation: EnrichedConversation | null = null; let defaultConversation: EnrichedConversation | null = null;
if (!afterCursor) { if (!afterCursor) {
try { try {
const defaultMessages = await client.conversations.messages.list( const defaultMessages = await client.agents.messages.list(agentId, {
"default", conversation_id: "default",
{ limit: 20,
limit: 20, order: "desc",
order: "desc", });
agent_id: agentId,
},
);
const defaultMsgItems = defaultMessages.getPaginatedItems(); const defaultMsgItems = defaultMessages.getPaginatedItems();
if (defaultMsgItems.length > 0) { if (defaultMsgItems.length > 0) {
const defaultStats = getMessageStats( const defaultStats = getMessageStats(

View File

@@ -24,6 +24,7 @@ export type ClassifyApprovalsOptions<TContext = ApprovalContext | null> = {
getContext?: ( getContext?: (
toolName: string, toolName: string,
parsedArgs: Record<string, unknown>, parsedArgs: Record<string, unknown>,
workingDirectory?: string,
) => Promise<TContext>; ) => Promise<TContext>;
alwaysRequiresUserInput?: (toolName: string) => boolean; alwaysRequiresUserInput?: (toolName: string) => boolean;
treatAskAsDeny?: boolean; treatAskAsDeny?: boolean;
@@ -31,6 +32,7 @@ export type ClassifyApprovalsOptions<TContext = ApprovalContext | null> = {
missingNameReason?: string; missingNameReason?: string;
requireArgsForAutoApprove?: boolean; requireArgsForAutoApprove?: boolean;
missingArgsReason?: (missing: string[]) => string; missingArgsReason?: (missing: string[]) => string;
workingDirectory?: string;
}; };
export async function getMissingRequiredArgs( export async function getMissingRequiredArgs(
@@ -74,9 +76,13 @@ export async function classifyApprovals<TContext = ApprovalContext | null>(
approval.toolArgs || "{}", approval.toolArgs || "{}",
{}, {},
); );
const permission = await checkToolPermission(toolName, parsedArgs); const permission = await checkToolPermission(
toolName,
parsedArgs,
opts.workingDirectory,
);
const context = opts.getContext const context = opts.getContext
? await opts.getContext(toolName, parsedArgs) ? await opts.getContext(toolName, parsedArgs, opts.workingDirectory)
: null; : null;
let decision = permission.decision; let decision = permission.decision;

View File

@@ -159,12 +159,12 @@ export async function runMessagesSubcommand(argv: string[]): Promise<number> {
return 1; return 1;
} }
const response = await client.conversations.messages.list("default", { const response = await client.agents.messages.list(agentId, {
conversation_id: "default",
limit: parseLimit(parsed.values.limit, 20), limit: parseLimit(parsed.values.limit, 20),
after: parsed.values.after, after: parsed.values.after,
before: parsed.values.before, before: parsed.values.before,
order, order,
agent_id: agentId,
}); });
const messages = response.getPaginatedItems() ?? []; const messages = response.getPaginatedItems() ?? [];

View File

@@ -4,7 +4,7 @@
* and only sends hunks, which is sufficient for rendering. * and only sends hunks, which is sufficient for rendering.
*/ */
import { basename } from "node:path"; import path, { basename } from "node:path";
import type { AdvancedDiffResult, AdvancedHunk } from "../cli/helpers/diff"; import type { AdvancedDiffResult, AdvancedHunk } from "../cli/helpers/diff";
import type { DiffHunk, DiffHunkLine, DiffPreview } from "../types/protocol"; import type { DiffHunk, DiffHunkLine, DiffPreview } from "../types/protocol";
@@ -124,6 +124,7 @@ async function getDiffDeps(): Promise<DiffDeps> {
export async function computeDiffPreviews( export async function computeDiffPreviews(
toolName: string, toolName: string,
toolArgs: Record<string, unknown>, toolArgs: Record<string, unknown>,
workingDirectory: string = process.env.USER_CWD || process.cwd(),
): Promise<DiffPreview[]> { ): Promise<DiffPreview[]> {
const { const {
computeAdvancedDiff, computeAdvancedDiff,
@@ -139,9 +140,12 @@ export async function computeDiffPreviews(
if (isFileWriteTool(toolName)) { if (isFileWriteTool(toolName)) {
const filePath = toolArgs.file_path as string | undefined; const filePath = toolArgs.file_path as string | undefined;
if (filePath) { if (filePath) {
const resolvedFilePath = path.isAbsolute(filePath)
? filePath
: path.resolve(workingDirectory, filePath);
const result = computeAdvancedDiff({ const result = computeAdvancedDiff({
kind: "write", kind: "write",
filePath, filePath: resolvedFilePath,
content: (toolArgs.content as string) || "", content: (toolArgs.content as string) || "",
}); });
previews.push(toDiffPreview(result, basename(filePath))); previews.push(toDiffPreview(result, basename(filePath)));
@@ -149,10 +153,13 @@ export async function computeDiffPreviews(
} else if (isFileEditTool(toolName)) { } else if (isFileEditTool(toolName)) {
const filePath = toolArgs.file_path as string | undefined; const filePath = toolArgs.file_path as string | undefined;
if (filePath) { if (filePath) {
const resolvedFilePath = path.isAbsolute(filePath)
? filePath
: path.resolve(workingDirectory, filePath);
if (toolArgs.edits && Array.isArray(toolArgs.edits)) { if (toolArgs.edits && Array.isArray(toolArgs.edits)) {
const result = computeAdvancedDiff({ const result = computeAdvancedDiff({
kind: "multi_edit", kind: "multi_edit",
filePath, filePath: resolvedFilePath,
edits: toolArgs.edits as Array<{ edits: toolArgs.edits as Array<{
old_string: string; old_string: string;
new_string: string; new_string: string;
@@ -163,7 +170,7 @@ export async function computeDiffPreviews(
} else { } else {
const result = computeAdvancedDiff({ const result = computeAdvancedDiff({
kind: "edit", kind: "edit",
filePath, filePath: resolvedFilePath,
oldString: (toolArgs.old_string as string) || "", oldString: (toolArgs.old_string as string) || "",
newString: (toolArgs.new_string as string) || "", newString: (toolArgs.new_string as string) || "",
replaceAll: toolArgs.replace_all as boolean | undefined, replaceAll: toolArgs.replace_all as boolean | undefined,

View File

@@ -103,7 +103,7 @@ describe("getResumeData", () => {
const conversationsRetrieve = mock(async () => ({ const conversationsRetrieve = mock(async () => ({
in_context_message_ids: ["msg-last"], in_context_message_ids: ["msg-last"],
})); }));
const conversationsList = mock(async () => ({ const agentsList = mock(async () => ({
getPaginatedItems: () => [ getPaginatedItems: () => [
makeUserMessage("msg-a"), makeUserMessage("msg-a"),
makeUserMessage("msg-b"), makeUserMessage("msg-b"),
@@ -114,15 +114,15 @@ describe("getResumeData", () => {
const client = { const client = {
conversations: { conversations: {
retrieve: conversationsRetrieve, retrieve: conversationsRetrieve,
messages: { list: conversationsList },
}, },
agents: { messages: { list: agentsList } },
messages: { retrieve: messagesRetrieve }, messages: { retrieve: messagesRetrieve },
} as unknown as Letta; } as unknown as Letta;
const resume = await getResumeData(client, makeAgent(), "default"); const resume = await getResumeData(client, makeAgent(), "default");
expect(messagesRetrieve).toHaveBeenCalledTimes(1); expect(messagesRetrieve).toHaveBeenCalledTimes(1);
expect(conversationsList).toHaveBeenCalledTimes(1); expect(agentsList).toHaveBeenCalledTimes(1);
expect(resume.pendingApprovals).toHaveLength(0); expect(resume.pendingApprovals).toHaveLength(0);
expect(resume.messageHistory.length).toBeGreaterThan(0); expect(resume.messageHistory.length).toBeGreaterThan(0);
}); });

View File

@@ -1,4 +1,7 @@
import { describe, expect, it } from "bun:test"; import { describe, expect, it } from "bun:test";
import { mkdir, mkdtemp, rm, writeFile } from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import type { import type {
AdvancedDiffFallback, AdvancedDiffFallback,
AdvancedDiffSuccess, AdvancedDiffSuccess,
@@ -213,4 +216,32 @@ describe("computeDiffPreviews", () => {
expect(previews).toHaveLength(2); expect(previews).toHaveLength(2);
expect(previews.map((p) => p.fileName).sort()).toEqual(["a.txt", "b.txt"]); expect(previews.map((p) => p.fileName).sort()).toEqual(["a.txt", "b.txt"]);
}); });
it("resolves relative file paths against the provided working directory", async () => {
const tempRoot = await mkdtemp(
path.join(os.tmpdir(), "letta-diff-preview-"),
);
const workspaceDir = path.join(tempRoot, "workspace");
const nestedDir = path.join(workspaceDir, "nested");
const targetFile = path.join(nestedDir, "sample.txt");
await mkdir(nestedDir, { recursive: true });
await writeFile(targetFile, "old content", "utf8");
try {
const previews = await computeDiffPreviews(
"edit",
{
file_path: "nested/sample.txt",
old_string: "old content",
new_string: "new content",
},
workspaceDir,
);
expect(previews).toHaveLength(1);
expect(previews[0]?.mode).toBe("advanced");
expect(previews[0]?.fileName).toBe("sample.txt");
} finally {
await rm(tempRoot, { recursive: true, force: true });
}
});
}); });

View File

@@ -1,4 +1,7 @@
import { describe, expect, test } from "bun:test"; import { describe, expect, test } from "bun:test";
import { mkdir, mkdtemp, realpath, rm } from "node:fs/promises";
import os from "node:os";
import { join } from "node:path";
import type { ApprovalCreate } from "@letta-ai/letta-client/resources/agents/messages"; import type { ApprovalCreate } from "@letta-ai/letta-client/resources/agents/messages";
import WebSocket from "ws"; import WebSocket from "ws";
import { buildConversationMessagesCreateRequestBody } from "../../agent/message"; import { buildConversationMessagesCreateRequestBody } from "../../agent/message";
@@ -254,6 +257,114 @@ describe("listen-client state_response control protocol", () => {
expect(typeof snapshot.cwd).toBe("string"); expect(typeof snapshot.cwd).toBe("string");
expect(snapshot.cwd.length).toBeGreaterThan(0); expect(snapshot.cwd.length).toBeGreaterThan(0);
expect(snapshot.configured_cwd).toBe(snapshot.cwd);
expect(snapshot.active_turn_cwd).toBeNull();
expect(snapshot.cwd_agent_id).toBeNull();
expect(snapshot.cwd_conversation_id).toBe("default");
});
test("scopes configured and active cwd to the requested agent and conversation", () => {
const runtime = __listenClientTestUtils.createRuntime();
__listenClientTestUtils.setConversationWorkingDirectory(
runtime,
"agent-a",
"conv-a",
"/repo/a",
);
__listenClientTestUtils.setConversationWorkingDirectory(
runtime,
"agent-b",
"default",
"/repo/b",
);
runtime.activeAgentId = "agent-a";
runtime.activeConversationId = "conv-a";
runtime.activeWorkingDirectory = "/repo/a";
const activeSnapshot = __listenClientTestUtils.buildStateResponse(
runtime,
2,
"agent-a",
"conv-a",
);
expect(activeSnapshot.configured_cwd).toBe("/repo/a");
expect(activeSnapshot.active_turn_cwd).toBe("/repo/a");
expect(activeSnapshot.cwd_agent_id).toBe("agent-a");
expect(activeSnapshot.cwd_conversation_id).toBe("conv-a");
const defaultSnapshot = __listenClientTestUtils.buildStateResponse(
runtime,
3,
"agent-b",
"default",
);
expect(defaultSnapshot.configured_cwd).toBe("/repo/b");
expect(defaultSnapshot.active_turn_cwd).toBeNull();
expect(defaultSnapshot.cwd_agent_id).toBe("agent-b");
expect(defaultSnapshot.cwd_conversation_id).toBe("default");
});
});
describe("listen-client cwd change handling", () => {
test("resolves relative cwd changes against the conversation cwd and preserves active turn cwd", async () => {
const runtime = __listenClientTestUtils.createRuntime();
const socket = new MockSocket(WebSocket.OPEN);
const tempRoot = await mkdtemp(join(os.tmpdir(), "letta-listen-cwd-"));
const repoDir = join(tempRoot, "repo");
const serverDir = join(repoDir, "server");
const clientDir = join(repoDir, "client");
await mkdir(serverDir, { recursive: true });
await mkdir(clientDir, { recursive: true });
const normalizedServerDir = await realpath(serverDir);
const normalizedClientDir = await realpath(clientDir);
try {
__listenClientTestUtils.setConversationWorkingDirectory(
runtime,
"agent-1",
"conv-1",
normalizedServerDir,
);
runtime.activeAgentId = "agent-1";
runtime.activeConversationId = "conv-1";
runtime.activeWorkingDirectory = normalizedServerDir;
await __listenClientTestUtils.handleCwdChange(
{
type: "change_cwd",
agentId: "agent-1",
conversationId: "conv-1",
cwd: "../client",
},
socket as unknown as WebSocket,
runtime,
);
expect(
__listenClientTestUtils.getConversationWorkingDirectory(
runtime,
"agent-1",
"conv-1",
),
).toBe(normalizedClientDir);
expect(socket.sentPayloads).toHaveLength(2);
const changed = JSON.parse(socket.sentPayloads[0] as string);
expect(changed.type).toBe("cwd_changed");
expect(changed.success).toBe(true);
expect(changed.agent_id).toBe("agent-1");
expect(changed.cwd).toBe(normalizedClientDir);
expect(changed.conversation_id).toBe("conv-1");
const snapshot = JSON.parse(socket.sentPayloads[1] as string);
expect(snapshot.type).toBe("state_response");
expect(snapshot.configured_cwd).toBe(normalizedClientDir);
expect(snapshot.active_turn_cwd).toBe(normalizedServerDir);
expect(snapshot.cwd_agent_id).toBe("agent-1");
expect(snapshot.cwd_conversation_id).toBe("conv-1");
} finally {
await rm(tempRoot, { recursive: true, force: true });
}
}); });
}); });

View File

@@ -278,6 +278,7 @@ type ToolExecutionContextSnapshot = {
toolRegistry: ToolRegistry; toolRegistry: ToolRegistry;
externalTools: Map<string, ExternalToolDefinition>; externalTools: Map<string, ExternalToolDefinition>;
externalExecutor?: ExternalToolExecutor; externalExecutor?: ExternalToolExecutor;
workingDirectory: string;
}; };
export type CapturedToolExecutionContext = { export type CapturedToolExecutionContext = {
@@ -586,11 +587,14 @@ export function getClientToolsFromRegistry(): ClientTool[] {
* The returned context id can be used later to execute tool calls against this * The returned context id can be used later to execute tool calls against this
* exact snapshot even if the global registry changes between dispatch and execute. * exact snapshot even if the global registry changes between dispatch and execute.
*/ */
export function captureToolExecutionContext(): CapturedToolExecutionContext { export function captureToolExecutionContext(
workingDirectory: string = process.env.USER_CWD || process.cwd(),
): CapturedToolExecutionContext {
const snapshot: ToolExecutionContextSnapshot = { const snapshot: ToolExecutionContextSnapshot = {
toolRegistry: new Map(toolRegistry), toolRegistry: new Map(toolRegistry),
externalTools: new Map(getExternalToolsRegistry()), externalTools: new Map(getExternalToolsRegistry()),
externalExecutor: getExternalToolExecutor(), externalExecutor: getExternalToolExecutor(),
workingDirectory,
}; };
const contextId = saveExecutionContext(snapshot); const contextId = saveExecutionContext(snapshot);
@@ -615,6 +619,27 @@ export function captureToolExecutionContext(): CapturedToolExecutionContext {
}; };
} }
async function withExecutionWorkingDirectory<T>(
workingDirectory: string | undefined,
fn: () => Promise<T>,
): Promise<T> {
if (!workingDirectory) {
return fn();
}
const previousUserCwd = process.env.USER_CWD;
process.env.USER_CWD = workingDirectory;
try {
return await fn();
} finally {
if (previousUserCwd === undefined) {
delete process.env.USER_CWD;
} else {
process.env.USER_CWD = previousUserCwd;
}
}
}
/** /**
* Get permissions for a specific tool. * Get permissions for a specific tool.
* @param toolName - The name of the tool * @param toolName - The name of the tool
@@ -1158,6 +1183,7 @@ export async function executeTool(
context?.externalTools ?? getExternalToolsRegistry(); context?.externalTools ?? getExternalToolsRegistry();
const activeExternalExecutor = const activeExternalExecutor =
context?.externalExecutor ?? getExternalToolExecutor(); context?.externalExecutor ?? getExternalToolExecutor();
const workingDirectory = context?.workingDirectory;
// Check if this is an external tool (SDK-executed) // Check if this is an external tool (SDK-executed)
if (activeExternalTools.has(name)) { if (activeExternalTools.has(name)) {
@@ -1192,6 +1218,7 @@ export async function executeTool(
internalName, internalName,
args as Record<string, unknown>, args as Record<string, unknown>,
options?.toolCallId, options?.toolCallId,
workingDirectory,
); );
if (preHookResult.blocked) { if (preHookResult.blocked) {
const feedback = preHookResult.feedback.join("\n") || "Blocked by hook"; const feedback = preHookResult.feedback.join("\n") || "Blocked by hook";
@@ -1229,7 +1256,9 @@ export async function executeTool(
enhancedArgs = { ...enhancedArgs, toolCallId: options.toolCallId }; enhancedArgs = { ...enhancedArgs, toolCallId: options.toolCallId };
} }
const result = await tool.fn(enhancedArgs); const result = await withExecutionWorkingDirectory(workingDirectory, () =>
tool.fn(enhancedArgs),
);
const duration = Date.now() - startTime; const duration = Date.now() - startTime;
// Extract stdout/stderr if present (for bash tools) // Extract stdout/stderr if present (for bash tools)
@@ -1271,7 +1300,7 @@ export async function executeTool(
output: getDisplayableToolReturn(flattenedResponse), output: getDisplayableToolReturn(flattenedResponse),
}, },
options?.toolCallId, options?.toolCallId,
undefined, // workingDirectory workingDirectory,
undefined, // agentId undefined, // agentId
undefined, // precedingReasoning - not available in tool manager context undefined, // precedingReasoning - not available in tool manager context
undefined, // precedingAssistantMessage - not available in tool manager context undefined, // precedingAssistantMessage - not available in tool manager context
@@ -1295,7 +1324,7 @@ export async function executeTool(
errorOutput, errorOutput,
"tool_error", // error type for returned errors "tool_error", // error type for returned errors
options?.toolCallId, options?.toolCallId,
undefined, // workingDirectory workingDirectory,
undefined, // agentId undefined, // agentId
undefined, // precedingReasoning - not available in tool manager context undefined, // precedingReasoning - not available in tool manager context
undefined, // precedingAssistantMessage - not available in tool manager context undefined, // precedingAssistantMessage - not available in tool manager context
@@ -1378,7 +1407,7 @@ export async function executeTool(
args as Record<string, unknown>, args as Record<string, unknown>,
{ status: "error", output: errorMessage }, { status: "error", output: errorMessage },
options?.toolCallId, options?.toolCallId,
undefined, // workingDirectory workingDirectory,
undefined, // agentId undefined, // agentId
undefined, // precedingReasoning - not available in tool manager context undefined, // precedingReasoning - not available in tool manager context
undefined, // precedingAssistantMessage - not available in tool manager context undefined, // precedingAssistantMessage - not available in tool manager context
@@ -1397,7 +1426,7 @@ export async function executeTool(
errorMessage, errorMessage,
errorType, errorType,
options?.toolCallId, options?.toolCallId,
undefined, // workingDirectory workingDirectory,
undefined, // agentId undefined, // agentId
undefined, // precedingReasoning - not available in tool manager context undefined, // precedingReasoning - not available in tool manager context
undefined, // precedingAssistantMessage - not available in tool manager context undefined, // precedingAssistantMessage - not available in tool manager context

View File

@@ -3,6 +3,8 @@
* Connects to Letta Cloud and receives messages to execute locally * Connects to Letta Cloud and receives messages to execute locally
*/ */
import { realpath, stat } from "node:fs/promises";
import path from "node:path";
import { APIError } from "@letta-ai/letta-client/core/error"; import { APIError } from "@letta-ai/letta-client/core/error";
import type { Stream } from "@letta-ai/letta-client/core/streaming"; import type { Stream } from "@letta-ai/letta-client/core/streaming";
import type { MessageCreate } from "@letta-ai/letta-client/resources/agents/agents"; import type { MessageCreate } from "@letta-ai/letta-client/resources/agents/agents";
@@ -157,6 +159,15 @@ interface GetStatusMessage {
interface GetStateMessage { interface GetStateMessage {
type: "get_state"; type: "get_state";
agentId?: string | null;
conversationId?: string | null;
}
interface ChangeCwdMessage {
type: "change_cwd";
agentId?: string | null;
conversationId?: string | null;
cwd: string;
} }
interface CancelRunMessage { interface CancelRunMessage {
@@ -188,6 +199,10 @@ interface StateResponseMessage {
generated_at: string; generated_at: string;
state_seq: number; state_seq: number;
cwd: string; cwd: string;
configured_cwd: string;
active_turn_cwd: string | null;
cwd_agent_id: string | null;
cwd_conversation_id: string | null;
mode: "default" | "acceptEdits" | "plan" | "bypassPermissions"; mode: "default" | "acceptEdits" | "plan" | "bypassPermissions";
is_processing: boolean; is_processing: boolean;
last_stop_reason: string | null; last_stop_reason: string | null;
@@ -223,6 +238,17 @@ interface StateResponseMessage {
event_seq?: number; event_seq?: number;
} }
interface CwdChangedMessage {
type: "cwd_changed";
agent_id: string | null;
conversation_id: string;
cwd: string;
success: boolean;
error?: string;
event_seq?: number;
session_id?: string;
}
type ServerMessage = type ServerMessage =
| PongMessage | PongMessage
| StatusMessage | StatusMessage
@@ -230,6 +256,7 @@ type ServerMessage =
| ModeChangeMessage | ModeChangeMessage
| GetStatusMessage | GetStatusMessage
| GetStateMessage | GetStateMessage
| ChangeCwdMessage
| CancelRunMessage | CancelRunMessage
| RecoverPendingApprovalsMessage | RecoverPendingApprovalsMessage
| WsControlResponse; | WsControlResponse;
@@ -238,6 +265,7 @@ type ClientMessage =
| RunStartedMessage | RunStartedMessage
| RunRequestErrorMessage | RunRequestErrorMessage
| ModeChangedMessage | ModeChangedMessage
| CwdChangedMessage
| StatusResponseMessage | StatusResponseMessage
| StateResponseMessage; | StateResponseMessage;
@@ -266,6 +294,7 @@ type ListenerRuntime = {
/** Active run metadata for reconnect snapshot state. */ /** Active run metadata for reconnect snapshot state. */
activeAgentId: string | null; activeAgentId: string | null;
activeConversationId: string | null; activeConversationId: string | null;
activeWorkingDirectory: string | null;
activeRunId: string | null; activeRunId: string | null;
activeRunStartedAt: string | null; activeRunStartedAt: string | null;
/** Abort controller for the currently active message turn. */ /** Abort controller for the currently active message turn. */
@@ -311,6 +340,8 @@ type ListenerRuntime = {
* Threaded into the next send for persistence normalization. * Threaded into the next send for persistence normalization.
*/ */
pendingInterruptedToolCallIds: string[] | null; pendingInterruptedToolCallIds: string[] | null;
bootWorkingDirectory: string;
workingDirectoryByConversation: Map<string, string>;
}; };
// Listen mode supports one active connection per process. // Listen mode supports one active connection per process.
@@ -354,11 +385,94 @@ function handleModeChange(msg: ModeChangeMessage, socket: WebSocket): void {
} }
} }
function normalizeCwdAgentId(agentId?: string | null): string | null {
return agentId && agentId.length > 0 ? agentId : null;
}
function getWorkingDirectoryScopeKey(
agentId?: string | null,
conversationId?: string | null,
): string {
const normalizedConversationId = normalizeConversationId(conversationId);
const normalizedAgentId = normalizeCwdAgentId(agentId);
if (normalizedConversationId === "default") {
return `agent:${normalizedAgentId ?? "__unknown__"}::conversation:default`;
}
return `conversation:${normalizedConversationId}`;
}
async function handleCwdChange(
msg: ChangeCwdMessage,
socket: WebSocket,
runtime: ListenerRuntime,
): Promise<void> {
const conversationId = normalizeConversationId(msg.conversationId);
const agentId = normalizeCwdAgentId(msg.agentId);
const currentWorkingDirectory = getConversationWorkingDirectory(
runtime,
agentId,
conversationId,
);
try {
const requestedPath = msg.cwd?.trim();
if (!requestedPath) {
throw new Error("Working directory cannot be empty");
}
const resolvedPath = path.isAbsolute(requestedPath)
? requestedPath
: path.resolve(currentWorkingDirectory, requestedPath);
const normalizedPath = await realpath(resolvedPath);
const stats = await stat(normalizedPath);
if (!stats.isDirectory()) {
throw new Error(`Not a directory: ${normalizedPath}`);
}
setConversationWorkingDirectory(
runtime,
agentId,
conversationId,
normalizedPath,
);
sendClientMessage(
socket,
{
type: "cwd_changed",
agent_id: agentId,
conversation_id: conversationId,
cwd: normalizedPath,
success: true,
},
runtime,
);
sendStateSnapshot(socket, runtime, agentId, conversationId);
} catch (error) {
sendClientMessage(
socket,
{
type: "cwd_changed",
agent_id: agentId,
conversation_id: conversationId,
cwd: msg.cwd,
success: false,
error:
error instanceof Error
? error.message
: "Working directory change failed",
},
runtime,
);
}
}
const MAX_RETRY_DURATION_MS = 5 * 60 * 1000; // 5 minutes const MAX_RETRY_DURATION_MS = 5 * 60 * 1000; // 5 minutes
const INITIAL_RETRY_DELAY_MS = 1000; // 1 second const INITIAL_RETRY_DELAY_MS = 1000; // 1 second
const MAX_RETRY_DELAY_MS = 30000; // 30 seconds const MAX_RETRY_DELAY_MS = 30000; // 30 seconds
function createRuntime(): ListenerRuntime { function createRuntime(): ListenerRuntime {
const bootWorkingDirectory = process.env.USER_CWD || process.cwd();
const runtime: ListenerRuntime = { const runtime: ListenerRuntime = {
socket: null, socket: null,
heartbeatInterval: null, heartbeatInterval: null,
@@ -373,6 +487,7 @@ function createRuntime(): ListenerRuntime {
isProcessing: false, isProcessing: false,
activeAgentId: null, activeAgentId: null,
activeConversationId: null, activeConversationId: null,
activeWorkingDirectory: null,
activeRunId: null, activeRunId: null,
activeRunStartedAt: null, activeRunStartedAt: null,
activeAbortController: null, activeAbortController: null,
@@ -384,6 +499,8 @@ function createRuntime(): ListenerRuntime {
continuationEpoch: 0, continuationEpoch: 0,
activeExecutingToolCallIds: [], activeExecutingToolCallIds: [],
pendingInterruptedToolCallIds: null, pendingInterruptedToolCallIds: null,
bootWorkingDirectory,
workingDirectoryByConversation: new Map<string, string>(),
coalescedSkipQueueItemIds: new Set<string>(), coalescedSkipQueueItemIds: new Set<string>(),
pendingTurns: 0, pendingTurns: 0,
// queueRuntime assigned below — needs runtime ref in callbacks // queueRuntime assigned below — needs runtime ref in callbacks
@@ -462,6 +579,39 @@ function createRuntime(): ListenerRuntime {
return runtime; return runtime;
} }
function normalizeConversationId(conversationId?: string | null): string {
return conversationId && conversationId.length > 0
? conversationId
: "default";
}
function getConversationWorkingDirectory(
runtime: ListenerRuntime,
agentId?: string | null,
conversationId?: string | null,
): string {
const scopeKey = getWorkingDirectoryScopeKey(agentId, conversationId);
return (
runtime.workingDirectoryByConversation.get(scopeKey) ??
runtime.bootWorkingDirectory
);
}
function setConversationWorkingDirectory(
runtime: ListenerRuntime,
agentId: string | null,
conversationId: string,
workingDirectory: string,
): void {
const scopeKey = getWorkingDirectoryScopeKey(agentId, conversationId);
if (workingDirectory === runtime.bootWorkingDirectory) {
runtime.workingDirectoryByConversation.delete(scopeKey);
return;
}
runtime.workingDirectoryByConversation.set(scopeKey, workingDirectory);
}
function clearRuntimeTimers(runtime: ListenerRuntime): void { function clearRuntimeTimers(runtime: ListenerRuntime): void {
if (runtime.reconnectTimeout) { if (runtime.reconnectTimeout) {
clearTimeout(runtime.reconnectTimeout); clearTimeout(runtime.reconnectTimeout);
@@ -476,6 +626,7 @@ function clearRuntimeTimers(runtime: ListenerRuntime): void {
function clearActiveRunState(runtime: ListenerRuntime): void { function clearActiveRunState(runtime: ListenerRuntime): void {
runtime.activeAgentId = null; runtime.activeAgentId = null;
runtime.activeConversationId = null; runtime.activeConversationId = null;
runtime.activeWorkingDirectory = null;
runtime.activeRunId = null; runtime.activeRunId = null;
runtime.activeRunStartedAt = null; runtime.activeRunStartedAt = null;
runtime.activeAbortController = null; runtime.activeAbortController = null;
@@ -615,6 +766,7 @@ export function parseServerMessage(
parsed.type === "mode_change" || parsed.type === "mode_change" ||
parsed.type === "get_status" || parsed.type === "get_status" ||
parsed.type === "get_state" || parsed.type === "get_state" ||
parsed.type === "change_cwd" ||
parsed.type === "cancel_run" || parsed.type === "cancel_run" ||
parsed.type === "recover_pending_approvals" parsed.type === "recover_pending_approvals"
) { ) {
@@ -692,7 +844,21 @@ function mergeDequeuedBatchContent(
function buildStateResponse( function buildStateResponse(
runtime: ListenerRuntime, runtime: ListenerRuntime,
stateSeq: number, stateSeq: number,
agentId?: string | null,
conversationId?: string | null,
): StateResponseMessage { ): StateResponseMessage {
const scopedAgentId = normalizeCwdAgentId(agentId);
const scopedConversationId = normalizeConversationId(conversationId);
const configuredWorkingDirectory = getConversationWorkingDirectory(
runtime,
scopedAgentId,
scopedConversationId,
);
const activeTurnWorkingDirectory =
runtime.activeAgentId === scopedAgentId &&
runtime.activeConversationId === scopedConversationId
? runtime.activeWorkingDirectory
: null;
const queueItems = runtime.queueRuntime.items.map((item) => ({ const queueItems = runtime.queueRuntime.items.map((item) => ({
id: item.id, id: item.id,
client_message_id: item.clientMessageId ?? `cm-${item.id}`, client_message_id: item.clientMessageId ?? `cm-${item.id}`,
@@ -724,7 +890,11 @@ function buildStateResponse(
generated_at: new Date().toISOString(), generated_at: new Date().toISOString(),
state_seq: stateSeq, state_seq: stateSeq,
event_seq: stateSeq, event_seq: stateSeq,
cwd: process.env.USER_CWD || process.cwd(), cwd: configuredWorkingDirectory,
configured_cwd: configuredWorkingDirectory,
active_turn_cwd: activeTurnWorkingDirectory,
cwd_agent_id: scopedAgentId,
cwd_conversation_id: scopedConversationId,
mode: permissionMode.getMode(), mode: permissionMode.getMode(),
is_processing: runtime.isProcessing, is_processing: runtime.isProcessing,
last_stop_reason: runtime.lastStopReason, last_stop_reason: runtime.lastStopReason,
@@ -745,12 +915,22 @@ function buildStateResponse(
}; };
} }
function sendStateSnapshot(socket: WebSocket, runtime: ListenerRuntime): void { function sendStateSnapshot(
socket: WebSocket,
runtime: ListenerRuntime,
agentId?: string | null,
conversationId?: string | null,
): void {
const stateSeq = nextEventSeq(runtime); const stateSeq = nextEventSeq(runtime);
if (stateSeq === null) { if (stateSeq === null) {
return; return;
} }
const stateResponse = buildStateResponse(runtime, stateSeq); const stateResponse = buildStateResponse(
runtime,
stateSeq,
agentId,
conversationId,
);
sendClientMessage(socket, stateResponse, runtime); sendClientMessage(socket, stateResponse, runtime);
} }
@@ -1508,6 +1688,13 @@ async function resolveStaleApprovals(
agentId: runtime.activeAgentId, agentId: runtime.activeAgentId,
streamTokens: true, streamTokens: true,
background: true, background: true,
workingDirectory:
runtime.activeWorkingDirectory ??
getConversationWorkingDirectory(
runtime,
runtime.activeAgentId,
recoveryConversationId,
),
}, },
{ maxRetries: 0, signal: abortSignal }, { maxRetries: 0, signal: abortSignal },
); );
@@ -1751,6 +1938,17 @@ async function recoverPendingApprovals(
const requestedConversationId = msg.conversationId || undefined; const requestedConversationId = msg.conversationId || undefined;
const conversationId = requestedConversationId ?? "default"; const conversationId = requestedConversationId ?? "default";
const recoveryAgentId = normalizeCwdAgentId(agentId);
const recoveryWorkingDirectory =
runtime.activeAgentId === recoveryAgentId &&
runtime.activeConversationId === conversationId &&
runtime.activeWorkingDirectory
? runtime.activeWorkingDirectory
: getConversationWorkingDirectory(
runtime,
recoveryAgentId,
conversationId,
);
const client = await getClient(); const client = await getClient();
const agent = await client.agents.retrieve(agentId); const agent = await client.agents.retrieve(agentId);
@@ -1814,6 +2012,7 @@ async function recoverPendingApprovals(
alwaysRequiresUserInput: isInteractiveApprovalTool, alwaysRequiresUserInput: isInteractiveApprovalTool,
treatAskAsDeny: false, treatAskAsDeny: false,
requireArgsForAutoApprove: true, requireArgsForAutoApprove: true,
workingDirectory: recoveryWorkingDirectory,
}, },
); );
@@ -1858,6 +2057,7 @@ async function recoverPendingApprovals(
const diffs = await computeDiffPreviews( const diffs = await computeDiffPreviews(
ac.approval.toolName, ac.approval.toolName,
ac.parsedArgs, ac.parsedArgs,
recoveryWorkingDirectory,
); );
const controlRequest: ControlRequest = { const controlRequest: ControlRequest = {
@@ -1931,7 +2131,9 @@ async function recoverPendingApprovals(
return; return;
} }
const executionResults = await executeApprovalBatch(decisions); const executionResults = await executeApprovalBatch(decisions, undefined, {
workingDirectory: recoveryWorkingDirectory,
});
clearPendingApprovalBatchIds( clearPendingApprovalBatchIds(
runtime, runtime,
decisions.map((decision) => decision.approval), decisions.map((decision) => decision.approval),
@@ -2115,6 +2317,15 @@ async function connectWithRetry(
return; return;
} }
if (parsed.type === "change_cwd") {
if (runtime !== activeRuntime || runtime.intentionallyClosed) {
return;
}
void handleCwdChange(parsed, socket, runtime);
return;
}
// Handle status request from cloud (immediate response) // Handle status request from cloud (immediate response)
if (parsed.type === "get_status") { if (parsed.type === "get_status") {
if (runtime !== activeRuntime || runtime.intentionallyClosed) { if (runtime !== activeRuntime || runtime.intentionallyClosed) {
@@ -2219,12 +2430,21 @@ async function connectWithRetry(
if (runtime !== activeRuntime || runtime.intentionallyClosed) { if (runtime !== activeRuntime || runtime.intentionallyClosed) {
return; return;
} }
const requestedConversationId = normalizeConversationId(
parsed.conversationId,
);
const requestedAgentId = normalizeCwdAgentId(parsed.agentId);
// If we're blocked on an approval callback, don't queue behind the // If we're blocked on an approval callback, don't queue behind the
// pending turn; respond immediately so refreshed clients can render the // pending turn; respond immediately so refreshed clients can render the
// approval card needed to unblock execution. // approval card needed to unblock execution.
if (runtime.pendingApprovalResolvers.size > 0) { if (runtime.pendingApprovalResolvers.size > 0) {
sendStateSnapshot(socket, runtime); sendStateSnapshot(
socket,
runtime,
requestedAgentId,
requestedConversationId,
);
return; return;
} }
@@ -2236,7 +2456,12 @@ async function connectWithRetry(
return; return;
} }
sendStateSnapshot(socket, runtime); sendStateSnapshot(
socket,
runtime,
requestedAgentId,
requestedConversationId,
);
}) })
.catch((error: unknown) => { .catch((error: unknown) => {
if (process.env.DEBUG) { if (process.env.DEBUG) {
@@ -2507,6 +2732,12 @@ async function handleIncomingMessage(
const agentId = msg.agentId; const agentId = msg.agentId;
const requestedConversationId = msg.conversationId || undefined; const requestedConversationId = msg.conversationId || undefined;
const conversationId = requestedConversationId ?? "default"; const conversationId = requestedConversationId ?? "default";
const normalizedAgentId = normalizeCwdAgentId(agentId);
const turnWorkingDirectory = getConversationWorkingDirectory(
runtime,
normalizedAgentId,
conversationId,
);
const msgStartTime = performance.now(); const msgStartTime = performance.now();
let msgTurnCount = 0; let msgTurnCount = 0;
const msgRunIds: string[] = []; const msgRunIds: string[] = [];
@@ -2523,6 +2754,7 @@ async function handleIncomingMessage(
runtime.activeAbortController = new AbortController(); runtime.activeAbortController = new AbortController();
runtime.activeAgentId = agentId ?? null; runtime.activeAgentId = agentId ?? null;
runtime.activeConversationId = conversationId; runtime.activeConversationId = conversationId;
runtime.activeWorkingDirectory = turnWorkingDirectory;
runtime.activeRunId = null; runtime.activeRunId = null;
runtime.activeRunStartedAt = new Date().toISOString(); runtime.activeRunStartedAt = new Date().toISOString();
runtime.activeExecutingToolCallIds = []; runtime.activeExecutingToolCallIds = [];
@@ -2566,6 +2798,7 @@ async function handleIncomingMessage(
agentId, agentId,
streamTokens: true, streamTokens: true,
background: true, background: true,
workingDirectory: turnWorkingDirectory,
...(queuedInterruptedToolCallIds.length > 0 ...(queuedInterruptedToolCallIds.length > 0
? { ? {
approvalNormalization: { approvalNormalization: {
@@ -2847,6 +3080,7 @@ async function handleIncomingMessage(
alwaysRequiresUserInput: isInteractiveApprovalTool, alwaysRequiresUserInput: isInteractiveApprovalTool,
treatAskAsDeny: false, // Let cloud UI handle approvals treatAskAsDeny: false, // Let cloud UI handle approvals
requireArgsForAutoApprove: true, requireArgsForAutoApprove: true,
workingDirectory: turnWorkingDirectory,
}); });
// Snapshot all tool_call_ids before entering approval wait so cancel can // Snapshot all tool_call_ids before entering approval wait so cancel can
@@ -2917,6 +3151,7 @@ async function handleIncomingMessage(
const diffs = await computeDiffPreviews( const diffs = await computeDiffPreviews(
ac.approval.toolName, ac.approval.toolName,
ac.parsedArgs, ac.parsedArgs,
turnWorkingDirectory,
); );
const controlRequest: ControlRequest = { const controlRequest: ControlRequest = {
@@ -3003,6 +3238,7 @@ async function handleIncomingMessage(
{ {
toolContextId: turnToolContextId ?? undefined, toolContextId: turnToolContextId ?? undefined,
abortSignal: runtime.activeAbortController.signal, abortSignal: runtime.activeAbortController.signal,
workingDirectory: turnWorkingDirectory,
}, },
); );
const persistedExecutionResults = const persistedExecutionResults =
@@ -3172,12 +3408,15 @@ export const __listenClientTestUtils = {
createRuntime, createRuntime,
stopRuntime, stopRuntime,
buildStateResponse, buildStateResponse,
handleCwdChange,
emitToWS, emitToWS,
getConversationWorkingDirectory,
rememberPendingApprovalBatchIds, rememberPendingApprovalBatchIds,
resolvePendingApprovalBatchId, resolvePendingApprovalBatchId,
resolveRecoveryBatchId, resolveRecoveryBatchId,
clearPendingApprovalBatchIds, clearPendingApprovalBatchIds,
populateInterruptQueue, populateInterruptQueue,
setConversationWorkingDirectory,
consumeInterruptQueue, consumeInterruptQueue,
extractInterruptToolReturns, extractInterruptToolReturns,
emitInterruptToolReturnMessage, emitInterruptToolReturnMessage,