feat: add override_model support for agent file import (#9058)
This commit is contained in:
committed by
Caren Thomas
parent
82c01368fc
commit
1d1bb29a43
@@ -27960,6 +27960,18 @@
|
||||
"title": "Embedding",
|
||||
"description": "Embedding handle to override with."
|
||||
},
|
||||
"model": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "Model handle to override the agent's default model. This allows the imported agent to use a different model while keeping other defaults (e.g., context size) from the original configuration."
|
||||
},
|
||||
"append_copy_suffix": {
|
||||
"type": "boolean",
|
||||
"title": "Append Copy Suffix",
|
||||
@@ -27993,6 +28005,19 @@
|
||||
"description": "Override import with specific embedding handle. Use 'embedding' instead.",
|
||||
"deprecated": true
|
||||
},
|
||||
"override_model_handle": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Override Model Handle",
|
||||
"description": "Model handle to override the agent's default model. Use 'model' instead.",
|
||||
"deprecated": true
|
||||
},
|
||||
"project_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -308,6 +308,7 @@ async def _import_agent(
|
||||
strip_messages: bool = False,
|
||||
env_vars: Optional[dict[str, Any]] = None,
|
||||
override_embedding_handle: Optional[str] = None,
|
||||
override_model_handle: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Import an agent using the new AgentFileSchema format.
|
||||
@@ -319,6 +320,11 @@ async def _import_agent(
|
||||
else:
|
||||
embedding_config_override = None
|
||||
|
||||
if override_model_handle:
|
||||
llm_config_override = await server.get_llm_config_from_handle_async(actor=actor, handle=override_model_handle)
|
||||
else:
|
||||
llm_config_override = None
|
||||
|
||||
import_result = await server.agent_serialization_manager.import_file(
|
||||
schema=agent_schema,
|
||||
actor=actor,
|
||||
@@ -327,6 +333,7 @@ async def _import_agent(
|
||||
override_existing_tools=override_existing_tools,
|
||||
env_vars=env_vars,
|
||||
override_embedding_config=embedding_config_override,
|
||||
override_llm_config=llm_config_override,
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
@@ -362,6 +369,10 @@ async def import_agent(
|
||||
None,
|
||||
description="Embedding handle to override with.",
|
||||
),
|
||||
model: Optional[str] = Form(
|
||||
None,
|
||||
description="Model handle to override the agent's default model. This allows the imported agent to use a different model while keeping other defaults (e.g., context size) from the original configuration.",
|
||||
),
|
||||
# Deprecated fields (maintain backward compatibility)
|
||||
append_copy_suffix: bool = Form(
|
||||
True,
|
||||
@@ -378,6 +389,11 @@ async def import_agent(
|
||||
description="Override import with specific embedding handle. Use 'embedding' instead.",
|
||||
deprecated=True,
|
||||
),
|
||||
override_model_handle: Optional[str] = Form(
|
||||
None,
|
||||
description="Model handle to override the agent's default model. Use 'model' instead.",
|
||||
deprecated=True,
|
||||
),
|
||||
project_id: str | None = Form(
|
||||
None, description="The project ID to associate the uploaded agent with. This is now passed via headers.", deprecated=True
|
||||
),
|
||||
@@ -408,6 +424,7 @@ async def import_agent(
|
||||
# Handle backward compatibility: prefer new field names over deprecated ones
|
||||
final_name = name or override_name
|
||||
final_embedding_handle = embedding or override_embedding_handle or x_override_embedding_model
|
||||
final_model_handle = model or override_model_handle
|
||||
|
||||
# Parse secrets (new) or env_vars_json (deprecated)
|
||||
env_vars = None
|
||||
@@ -440,6 +457,7 @@ async def import_agent(
|
||||
strip_messages=strip_messages,
|
||||
env_vars=env_vars,
|
||||
override_embedding_handle=final_embedding_handle,
|
||||
override_model_handle=final_model_handle,
|
||||
)
|
||||
else:
|
||||
# This is a legacy AgentSchema
|
||||
|
||||
@@ -30,6 +30,7 @@ from letta.schemas.agent_file import (
|
||||
)
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
@@ -472,6 +473,7 @@ class AgentSerializationManager:
|
||||
dry_run: bool = False,
|
||||
env_vars: Optional[Dict[str, Any]] = None,
|
||||
override_embedding_config: Optional[EmbeddingConfig] = None,
|
||||
override_llm_config: Optional[LLMConfig] = None,
|
||||
project_id: Optional[str] = None,
|
||||
) -> ImportResult:
|
||||
"""
|
||||
@@ -672,6 +674,11 @@ class AgentSerializationManager:
|
||||
agent_schema.embedding_config = override_embedding_config
|
||||
agent_schema.embedding = override_embedding_config.handle
|
||||
|
||||
# Override llm_config if provided (keeps other defaults like context size)
|
||||
if override_llm_config:
|
||||
agent_schema.llm_config = override_llm_config
|
||||
agent_schema.model = override_llm_config.handle
|
||||
|
||||
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
|
||||
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})
|
||||
|
||||
|
||||
@@ -235,6 +235,15 @@ def embedding_handle_override():
|
||||
return "openai/text-embedding-ada-002"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def model_handle_override():
|
||||
# Use a different OpenAI model handle for override tests.
|
||||
# The default in tests is usually gpt-4o-mini, so we use gpt-4o.
|
||||
current_handle = LLMConfig.default_config("gpt-4o-mini").handle or "openai/gpt-4o-mini"
|
||||
assert current_handle != "openai/gpt-4o" # make sure it's different
|
||||
return "openai/gpt-4o"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def test_source(server: SyncServer, default_user):
|
||||
"""Fixture to create and return a test source."""
|
||||
@@ -1166,6 +1175,52 @@ class TestAgentFileImport:
|
||||
imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user)
|
||||
assert imported_agent.embedding_config.handle == embedding_handle_override
|
||||
|
||||
async def test_basic_import_with_model_override(
|
||||
self, server, agent_serialization_manager, test_agent, default_user, other_user, model_handle_override
|
||||
):
|
||||
"""Test basic agent import functionality with LLM model override."""
|
||||
# Verify original agent has gpt-4o-mini (handle may be None for legacy configs)
|
||||
assert "gpt-4o-mini" in (test_agent.llm_config.handle or "") or "gpt-4o-mini" in (test_agent.llm_config.model or "")
|
||||
|
||||
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
|
||||
|
||||
llm_config_override = await server.get_llm_config_from_handle_async(actor=other_user, handle=model_handle_override)
|
||||
result = await agent_serialization_manager.import_file(agent_file, other_user, override_llm_config=llm_config_override)
|
||||
|
||||
assert result.success
|
||||
assert result.imported_count > 0
|
||||
assert len(result.id_mappings) > 0
|
||||
|
||||
for file_id, db_id in result.id_mappings.items():
|
||||
if file_id.startswith("agent-"):
|
||||
assert db_id != test_agent.id # New agent should have different ID
|
||||
|
||||
# check model handle was overridden
|
||||
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
|
||||
imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user)
|
||||
assert imported_agent.llm_config.handle == model_handle_override
|
||||
|
||||
async def test_basic_import_with_both_overrides(
|
||||
self, server, agent_serialization_manager, test_agent, default_user, other_user, embedding_handle_override, model_handle_override
|
||||
):
|
||||
"""Test agent import with both embedding and model overrides."""
|
||||
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
|
||||
|
||||
embedding_config_override = await server.get_embedding_config_from_handle_async(actor=other_user, handle=embedding_handle_override)
|
||||
llm_config_override = await server.get_llm_config_from_handle_async(actor=other_user, handle=model_handle_override)
|
||||
result = await agent_serialization_manager.import_file(
|
||||
agent_file, other_user, override_embedding_config=embedding_config_override, override_llm_config=llm_config_override
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.imported_count > 0
|
||||
|
||||
# Verify both overrides were applied
|
||||
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
|
||||
imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user)
|
||||
assert imported_agent.embedding_config.handle == embedding_handle_override
|
||||
assert imported_agent.llm_config.handle == model_handle_override
|
||||
|
||||
async def test_import_preserves_data(self, server, agent_serialization_manager, test_agent, default_user, other_user):
|
||||
"""Test that import preserves all important data."""
|
||||
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
|
||||
|
||||
Reference in New Issue
Block a user