From 1d1bb29a43515531bb763fa4d8332cba15b5fb25 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:45:24 -0800 Subject: [PATCH] feat: add override_model support for agent file import (#9058) --- fern/openapi.json | 25 +++++++++ letta/server/rest_api/routers/v1/agents.py | 18 ++++++ letta/services/agent_serialization_manager.py | 7 +++ tests/test_agent_serialization_v2.py | 55 +++++++++++++++++++ 4 files changed, 105 insertions(+) diff --git a/fern/openapi.json b/fern/openapi.json index 4e3fead5..d52bcaf6 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -27960,6 +27960,18 @@ "title": "Embedding", "description": "Embedding handle to override with." }, + "model": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Model", + "description": "Model handle to override the agent's default model. This allows the imported agent to use a different model while keeping other defaults (e.g., context size) from the original configuration." + }, "append_copy_suffix": { "type": "boolean", "title": "Append Copy Suffix", @@ -27993,6 +28005,19 @@ "description": "Override import with specific embedding handle. Use 'embedding' instead.", "deprecated": true }, + "override_model_handle": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Override Model Handle", + "description": "Model handle to override the agent's default model. Use 'model' instead.", + "deprecated": true + }, "project_id": { "anyOf": [ { diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index c5ee4ed5..35a2b898 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -308,6 +308,7 @@ async def _import_agent( strip_messages: bool = False, env_vars: Optional[dict[str, Any]] = None, override_embedding_handle: Optional[str] = None, + override_model_handle: Optional[str] = None, ) -> List[str]: """ Import an agent using the new AgentFileSchema format. @@ -319,6 +320,11 @@ async def _import_agent( else: embedding_config_override = None + if override_model_handle: + llm_config_override = await server.get_llm_config_from_handle_async(actor=actor, handle=override_model_handle) + else: + llm_config_override = None + import_result = await server.agent_serialization_manager.import_file( schema=agent_schema, actor=actor, @@ -327,6 +333,7 @@ async def _import_agent( override_existing_tools=override_existing_tools, env_vars=env_vars, override_embedding_config=embedding_config_override, + override_llm_config=llm_config_override, project_id=project_id, ) @@ -362,6 +369,10 @@ async def import_agent( None, description="Embedding handle to override with.", ), + model: Optional[str] = Form( + None, + description="Model handle to override the agent's default model. This allows the imported agent to use a different model while keeping other defaults (e.g., context size) from the original configuration.", + ), # Deprecated fields (maintain backward compatibility) append_copy_suffix: bool = Form( True, @@ -378,6 +389,11 @@ async def import_agent( description="Override import with specific embedding handle. Use 'embedding' instead.", deprecated=True, ), + override_model_handle: Optional[str] = Form( + None, + description="Model handle to override the agent's default model. Use 'model' instead.", + deprecated=True, + ), project_id: str | None = Form( None, description="The project ID to associate the uploaded agent with. This is now passed via headers.", deprecated=True ), @@ -408,6 +424,7 @@ async def import_agent( # Handle backward compatibility: prefer new field names over deprecated ones final_name = name or override_name final_embedding_handle = embedding or override_embedding_handle or x_override_embedding_model + final_model_handle = model or override_model_handle # Parse secrets (new) or env_vars_json (deprecated) env_vars = None @@ -440,6 +457,7 @@ async def import_agent( strip_messages=strip_messages, env_vars=env_vars, override_embedding_handle=final_embedding_handle, + override_model_handle=final_model_handle, ) else: # This is a legacy AgentSchema diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index eb58d022..46f39ca1 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -30,6 +30,7 @@ from letta.schemas.agent_file import ( ) from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig from letta.schemas.enums import FileProcessingStatus, VectorDBProvider from letta.schemas.file import FileMetadata from letta.schemas.group import Group, GroupCreate @@ -472,6 +473,7 @@ class AgentSerializationManager: dry_run: bool = False, env_vars: Optional[Dict[str, Any]] = None, override_embedding_config: Optional[EmbeddingConfig] = None, + override_llm_config: Optional[LLMConfig] = None, project_id: Optional[str] = None, ) -> ImportResult: """ @@ -672,6 +674,11 @@ class AgentSerializationManager: agent_schema.embedding_config = override_embedding_config agent_schema.embedding = override_embedding_config.handle + # Override llm_config if provided (keeps other defaults like context size) + if override_llm_config: + agent_schema.llm_config = override_llm_config + agent_schema.model = override_llm_config.handle + # Convert AgentSchema back to CreateAgent, remapping tool/block IDs agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"}) diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 26b2d966..8bc6f21b 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -235,6 +235,15 @@ def embedding_handle_override(): return "openai/text-embedding-ada-002" +@pytest.fixture(scope="function") +def model_handle_override(): + # Use a different OpenAI model handle for override tests. + # The default in tests is usually gpt-4o-mini, so we use gpt-4o. + current_handle = LLMConfig.default_config("gpt-4o-mini").handle or "openai/gpt-4o-mini" + assert current_handle != "openai/gpt-4o" # make sure it's different + return "openai/gpt-4o" + + @pytest.fixture(scope="function") async def test_source(server: SyncServer, default_user): """Fixture to create and return a test source.""" @@ -1166,6 +1175,52 @@ class TestAgentFileImport: imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) assert imported_agent.embedding_config.handle == embedding_handle_override + async def test_basic_import_with_model_override( + self, server, agent_serialization_manager, test_agent, default_user, other_user, model_handle_override + ): + """Test basic agent import functionality with LLM model override.""" + # Verify original agent has gpt-4o-mini (handle may be None for legacy configs) + assert "gpt-4o-mini" in (test_agent.llm_config.handle or "") or "gpt-4o-mini" in (test_agent.llm_config.model or "") + + agent_file = await agent_serialization_manager.export([test_agent.id], default_user) + + llm_config_override = await server.get_llm_config_from_handle_async(actor=other_user, handle=model_handle_override) + result = await agent_serialization_manager.import_file(agent_file, other_user, override_llm_config=llm_config_override) + + assert result.success + assert result.imported_count > 0 + assert len(result.id_mappings) > 0 + + for file_id, db_id in result.id_mappings.items(): + if file_id.startswith("agent-"): + assert db_id != test_agent.id # New agent should have different ID + + # check model handle was overridden + imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") + imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) + assert imported_agent.llm_config.handle == model_handle_override + + async def test_basic_import_with_both_overrides( + self, server, agent_serialization_manager, test_agent, default_user, other_user, embedding_handle_override, model_handle_override + ): + """Test agent import with both embedding and model overrides.""" + agent_file = await agent_serialization_manager.export([test_agent.id], default_user) + + embedding_config_override = await server.get_embedding_config_from_handle_async(actor=other_user, handle=embedding_handle_override) + llm_config_override = await server.get_llm_config_from_handle_async(actor=other_user, handle=model_handle_override) + result = await agent_serialization_manager.import_file( + agent_file, other_user, override_embedding_config=embedding_config_override, override_llm_config=llm_config_override + ) + + assert result.success + assert result.imported_count > 0 + + # Verify both overrides were applied + imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") + imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) + assert imported_agent.embedding_config.handle == embedding_handle_override + assert imported_agent.llm_config.handle == model_handle_override + async def test_import_preserves_data(self, server, agent_serialization_manager, test_agent, default_user, other_user): """Test that import preserves all important data.""" agent_file = await agent_serialization_manager.export([test_agent.id], default_user)