feat: Model based toolset switching (#111)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
// Utilities for modifying agent configuration
|
||||
|
||||
import type { LlmConfig } from "@letta-ai/letta-client/resources/models/models";
|
||||
import { getToolNames } from "../tools/manager";
|
||||
import { getAllLettaToolNames, getToolNames } from "../tools/manager";
|
||||
import { getClient } from "./client";
|
||||
|
||||
/**
|
||||
@@ -19,35 +19,42 @@ import { getClient } from "./client";
|
||||
*/
|
||||
export async function updateAgentLLMConfig(
|
||||
agentId: string,
|
||||
_modelHandle: string,
|
||||
modelHandle: string,
|
||||
updateArgs?: Record<string, unknown>,
|
||||
preserveParallelToolCalls?: boolean,
|
||||
): Promise<LlmConfig> {
|
||||
const client = await getClient();
|
||||
|
||||
// Get current agent to preserve parallel_tool_calls if requested
|
||||
// Step 1: change model (preserve parallel_tool_calls if requested)
|
||||
const currentAgent = await client.agents.retrieve(agentId);
|
||||
const originalParallelToolCalls = preserveParallelToolCalls
|
||||
? (currentAgent.llm_config?.parallel_tool_calls ?? undefined)
|
||||
const currentParallel = preserveParallelToolCalls
|
||||
? currentAgent.llm_config?.parallel_tool_calls
|
||||
: undefined;
|
||||
|
||||
// Strategy: Do everything in ONE modify call via llm_config
|
||||
// This avoids the backend resetting parallel_tool_calls when we update the model
|
||||
const updatedLlmConfig = {
|
||||
...currentAgent.llm_config,
|
||||
...updateArgs,
|
||||
// Explicitly preserve parallel_tool_calls
|
||||
...(originalParallelToolCalls !== undefined && {
|
||||
parallel_tool_calls: originalParallelToolCalls,
|
||||
}),
|
||||
} as LlmConfig;
|
||||
|
||||
await client.agents.modify(agentId, {
|
||||
llm_config: updatedLlmConfig,
|
||||
parallel_tool_calls: originalParallelToolCalls,
|
||||
await client.agents.update(agentId, {
|
||||
model: modelHandle,
|
||||
parallel_tool_calls: currentParallel,
|
||||
});
|
||||
|
||||
// Retrieve and return final state
|
||||
// Step 2: if there are llm_config overrides, apply them using fresh state
|
||||
if (updateArgs && Object.keys(updateArgs).length > 0) {
|
||||
const refreshed = await client.agents.retrieve(agentId);
|
||||
const refreshedConfig = (refreshed.llm_config || {}) as LlmConfig;
|
||||
|
||||
const mergedLlmConfig: LlmConfig = {
|
||||
...refreshedConfig,
|
||||
...(updateArgs as Record<string, unknown>),
|
||||
...(currentParallel !== undefined && {
|
||||
parallel_tool_calls: currentParallel,
|
||||
}),
|
||||
} as LlmConfig;
|
||||
|
||||
await client.agents.update(agentId, {
|
||||
llm_config: mergedLlmConfig,
|
||||
parallel_tool_calls: currentParallel,
|
||||
});
|
||||
}
|
||||
|
||||
const finalAgent = await client.agents.retrieve(agentId);
|
||||
return finalAgent.llm_config;
|
||||
}
|
||||
@@ -75,7 +82,9 @@ export async function linkToolsToAgent(agentId: string): Promise<LinkResult> {
|
||||
const client = await getClient();
|
||||
|
||||
// Get ALL agent tools from agent state
|
||||
const agent = await client.agents.retrieve(agentId);
|
||||
const agent = await client.agents.retrieve(agentId, {
|
||||
include: ["agent.tools"],
|
||||
});
|
||||
const currentTools = agent.tools || [];
|
||||
const currentToolIds = currentTools
|
||||
.map((t) => t.id)
|
||||
@@ -105,8 +114,8 @@ export async function linkToolsToAgent(agentId: string): Promise<LinkResult> {
|
||||
// Look up tool IDs from global tool list
|
||||
const toolsToAddIds: string[] = [];
|
||||
for (const toolName of toolsToAdd) {
|
||||
const tools = await client.tools.list({ name: toolName });
|
||||
const tool = tools[0];
|
||||
const toolsResponse = await client.tools.list({ name: toolName });
|
||||
const tool = toolsResponse.items[0];
|
||||
if (tool?.id) {
|
||||
toolsToAddIds.push(tool.id);
|
||||
}
|
||||
@@ -126,7 +135,7 @@ export async function linkToolsToAgent(agentId: string): Promise<LinkResult> {
|
||||
})),
|
||||
];
|
||||
|
||||
await client.agents.modify(agentId, {
|
||||
await client.agents.update(agentId, {
|
||||
tool_ids: newToolIds,
|
||||
tool_rules: newToolRules,
|
||||
});
|
||||
@@ -157,9 +166,11 @@ export async function unlinkToolsFromAgent(
|
||||
const client = await getClient();
|
||||
|
||||
// Get ALL agent tools from agent state (not tools.list which may be incomplete)
|
||||
const agent = await client.agents.retrieve(agentId);
|
||||
const agent = await client.agents.retrieve(agentId, {
|
||||
include: ["agent.tools"],
|
||||
});
|
||||
const allTools = agent.tools || [];
|
||||
const lettaCodeToolNames = new Set(getToolNames());
|
||||
const lettaCodeToolNames = new Set(getAllLettaToolNames());
|
||||
|
||||
// Filter out Letta Code tools, keep everything else
|
||||
const remainingTools = allTools.filter(
|
||||
@@ -180,7 +191,7 @@ export async function unlinkToolsFromAgent(
|
||||
!lettaCodeToolNames.has(rule.tool_name),
|
||||
);
|
||||
|
||||
await client.agents.modify(agentId, {
|
||||
await client.agents.update(agentId, {
|
||||
tool_ids: remainingToolIds,
|
||||
tool_rules: remainingToolRules,
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user