From 6e628a93f73a8dd514dfef5cff8770affb691565 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 27 Aug 2025 14:54:17 -0700 Subject: [PATCH] feat: support overriding embedding handle [LET-4021] (#4224) --- letta/server/rest_api/routers/v1/agents.py | 11 +++++++ letta/services/agent_serialization_manager.py | 19 ++++++++++-- tests/test_agent_serialization_v2.py | 29 +++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 5dad005a..987e1dc5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -252,6 +252,7 @@ async def import_agent( project_id: str | None = None, strip_messages: bool = False, env_vars: Optional[dict[str, Any]] = None, + override_embedding_handle: Optional[str] = None, ) -> List[str]: """ Import an agent using the new AgentFileSchema format. @@ -262,12 +263,18 @@ async def import_agent( raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}") try: + if override_embedding_handle: + embedding_config_override = server.get_cached_embedding_config_async(actor=actor, handle=override_embedding_handle) + else: + embedding_config_override = None + import_result = await server.agent_serialization_manager.import_file( schema=agent_schema, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools, env_vars=env_vars, + override_embedding_config=embedding_config_override, ) if not import_result.success: @@ -301,6 +308,10 @@ async def import_agent_serialized( True, description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.", ), + override_embedding_handle: Optional[str] = Form( + None, + description="Override import with specific embedding handle.", + ), project_id: str | None = Form(None, description="The project ID to associate the uploaded agent with."), strip_messages: bool = Form( False, diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 5635ed3b..047fd332 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -28,6 +28,7 @@ from letta.schemas.agent_file import ( ToolSchema, ) from letta.schemas.block import Block +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata from letta.schemas.group import Group, GroupCreate @@ -432,6 +433,7 @@ class AgentSerializationManager: override_existing_tools: bool = True, dry_run: bool = False, env_vars: Optional[Dict[str, Any]] = None, + override_embedding_config: Optional[EmbeddingConfig] = None, ) -> ImportResult: """ Import AgentFileSchema into the database. @@ -530,6 +532,12 @@ class AgentSerializationManager: source_names_to_check = [s.name for s in schema.sources] existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, actor) + # override embedding_config + if override_embedding_config: + for source_schema in schema.sources: + source_schema.embedding_config = override_embedding_config + source_schema.embedding = override_embedding_config.handle + for source_schema in schema.sources: source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"}) @@ -577,10 +585,12 @@ class AgentSerializationManager: # Start background tasks for file processing background_tasks = [] if schema.files and any(f.content for f in schema.files): + # Use override embedding config if provided, otherwise use agent's config + embedder_config = override_embedding_config if override_embedding_config else schema.agents[0].embedding_config if should_use_pinecone(): - embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config) + embedder = PineconeEmbedder(embedding_config=embedder_config) else: - embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config) + embedder = OpenAIEmbedder(embedding_config=embedder_config) file_processor = FileProcessor( file_parser=self.file_parser, embedder=embedder, @@ -613,6 +623,11 @@ class AgentSerializationManager: # 6. Create agents with empty message history for agent_schema in schema.agents: + # Override embedding_config if provided + if override_embedding_config: + agent_schema.embedding_config = override_embedding_config + agent_schema.embedding = override_embedding_config.handle + # Convert AgentSchema back to CreateAgent, remapping tool/block IDs agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"}) if append_copy_suffix: diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 7f6b1054..3a7c004d 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -205,6 +205,13 @@ def test_agent(server: SyncServer, default_user, default_organization, test_bloc yield agent_state +@pytest.fixture(scope="function") +def embedding_handle_override(): + current_handle = EmbeddingConfig.default_config(provider="openai").handle + assert current_handle != "letta/letta-free" # make sure its different + return "letta/letta-free" + + @pytest.fixture(scope="function") async def test_source(server: SyncServer, default_user): """Fixture to create and return a test source.""" @@ -1063,6 +1070,28 @@ class TestAgentFileImport: if file_id.startswith("agent-"): assert db_id != test_agent.id # New agent should have different ID + async def test_basic_import_with_embedding_override( + self, server, agent_serialization_manager, test_agent, default_user, other_user, embedding_handle_override + ): + """Test basic agent import functionality with embedding override.""" + agent_file = await agent_serialization_manager.export([test_agent.id], default_user) + + embedding_config_override = await server.get_cached_embedding_config_async(actor=other_user, handle=embedding_handle_override) + result = await agent_serialization_manager.import_file(agent_file, other_user, override_embedding_config=embedding_config_override) + + assert result.success + assert result.imported_count > 0 + assert len(result.id_mappings) > 0 + + for file_id, db_id in result.id_mappings.items(): + if file_id.startswith("agent-"): + assert db_id != test_agent.id # New agent should have different ID + + # check embedding handle + imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") + imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user) + assert imported_agent.embedding_config.handle == embedding_handle_override + async def test_import_preserves_data(self, server, agent_serialization_manager, test_agent, default_user, other_user): """Test that import preserves all important data.""" agent_file = await agent_serialization_manager.export([test_agent.id], default_user)