From a2f4ca5f890034cb3474ca32b94ebfc75d6c6abf Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 21 Aug 2025 16:23:37 -0700 Subject: [PATCH] fix: Fix bugs with exporting/importing agents with files (#4089) --- letta/schemas/agent_file.py | 2 +- letta/services/agent_serialization_manager.py | 33 ++++- letta/services/source_manager.py | 27 +++- poetry.lock | 6 +- tests/helpers/utils.py | 36 +++++ tests/test_managers.py | 53 +++++++ tests/test_sdk_client.py | 131 ++++++++++++++++++ tests/test_sources.py | 34 +---- 8 files changed, 279 insertions(+), 43 deletions(-) diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index 02d3988a..73477c2e 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -129,7 +129,7 @@ class AgentSchema(CreateAgent): memory_blocks=[], # TODO: Convert from agent_state.memory if needed tools=[], tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [], - source_ids=[], # [source.id for source in agent_state.sources] if agent_state.sources else [], + source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [], block_ids=[block.id for block in agent_state.memory.blocks], tool_rules=agent_state.tool_rules, tags=agent_state.tags, diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 56f324d7..722801cc 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -1,3 +1,4 @@ +import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -519,8 +520,20 @@ class AgentSerializationManager: if schema.sources: # convert source schemas to pydantic sources pydantic_sources = [] + + # First, do a fast batch check for existing source names to avoid conflicts + source_names_to_check = [s.name for s in schema.sources] + existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, 
actor) + for source_schema in schema.sources: source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"}) + + # Check if source name already exists, if so add unique suffix + original_name = source_data["name"] + if original_name in existing_source_names: + unique_suffix = uuid.uuid4().hex[:8] + source_data["name"] = f"{original_name}_{unique_suffix}" + pydantic_sources.append(Source(**source_data)) # bulk upsert all sources at once @@ -529,13 +542,15 @@ class AgentSerializationManager: # map file ids to database ids # note: sources are matched by name during upsert, so we need to match by name here too created_sources_by_name = {source.name: source for source in created_sources} - for source_schema in schema.sources: - created_source = created_sources_by_name.get(source_schema.name) + for i, source_schema in enumerate(schema.sources): + # Use the pydantic source name (which may have been modified for uniqueness) + source_name = pydantic_sources[i].name + created_source = created_sources_by_name.get(source_name) if created_source: file_to_db_ids[source_schema.id] = created_source.id imported_count += 1 else: - logger.warning(f"Source {source_schema.name} was not created during bulk upsert") + logger.warning(f"Source {source_name} was not created during bulk upsert") # 4. 
Create files (depends on sources) for file_schema in schema.files: @@ -595,6 +610,10 @@ class AgentSerializationManager: if agent_data.get("block_ids"): agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]] + # Remap source_ids from file IDs to database IDs + if agent_data.get("source_ids"): + agent_data["source_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["source_ids"]] + if env_vars: for var in agent_data["tool_exec_environment_variables"]: var["value"] = env_vars.get(var["key"], "") @@ -641,14 +660,16 @@ class AgentSerializationManager: for file_agent_schema in agent_schema.files_agents: file_db_id = file_to_db_ids[file_agent_schema.file_id] - # Use cached file metadata if available + # Use cached file metadata if available (with content) if file_db_id not in file_metadata_cache: - file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor) + file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id( + file_db_id, actor, include_content=True + ) file_metadata = file_metadata_cache[file_db_id] files_for_agent.append(file_metadata) if file_agent_schema.visible_content: - visible_content_map[file_db_id] = file_agent_schema.visible_content + visible_content_map[file_metadata.file_name] = file_agent_schema.visible_content # Bulk attach files to agent await self.file_agent_manager.attach_files_bulk( diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index dbab4f29..28b314b0 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -143,7 +143,6 @@ class SourceManager: update_dict[col.name] = excluded[col.name] upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict) - await session.execute(upsert_stmt) await session.commit() @@ -397,3 +396,29 @@ class SourceManager: sources_orm = result.scalars().all() return [source.to_pydantic() for source in sources_orm] + + 
@enforce_types + @trace_method + async def get_existing_source_names(self, source_names: List[str], actor: PydanticUser) -> set[str]: + """ + Fast batch check to see which source names already exist for the organization. + + Args: + source_names: List of source names to check + actor: User performing the action + + Returns: + Set of source names that already exist + """ + if not source_names: + return set() + + async with db_registry.async_session() as session: + query = select(SourceModel.name).where( + SourceModel.name.in_(source_names), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False + ) + + result = await session.execute(query) + existing_names = result.scalars().all() + + return set(existing_names) diff --git a/poetry.lock b/poetry.lock index 7a548938..3d2d408d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3473,13 +3473,13 @@ vcr = ["vcrpy (>=7.0.0)"] [[package]] name = "letta-client" -version = "0.1.271" +version = "0.1.272" description = "" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "letta_client-0.1.271-py3-none-any.whl", hash = "sha256:edbf6323e472202090113147b1c9ed280151d4966999686046d48c50c19c74fc"}, - {file = "letta_client-0.1.271.tar.gz", hash = "sha256:ae7944e594fe87dd80ce5057c42806e8c24b55e11f8fe6d05420fbc5af9b4180"}, + {file = "letta_client-0.1.272-py3-none-any.whl", hash = "sha256:ed5afffce9431e9dd1170c642efc68b1b5edadfe1923a467f017588dd371447e"}, + {file = "letta_client-0.1.272.tar.gz", hash = "sha256:40bb1e802aeabbb9cb6eaa2105eff7e8a704ac0962623e4b27d6320e57029dcc"}, ] [package.dependencies] diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 24467688..7ce5b989 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -278,3 +278,39 @@ async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: st with open(file_path, "rb") as f: uploaded = await client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False) return 
uploaded + + +def upload_file_and_wait( + client: Letta, + source_id: str, + file_path: str, + name: Optional[str] = None, + max_wait: int = 60, + duplicate_handling: Optional["DuplicateFileHandling"] = None, +): + """Helper function to upload a file and wait for processing to complete""" + from letta_client import DuplicateFileHandling as ClientDuplicateFileHandling + + with open(file_path, "rb") as f: + if duplicate_handling: + # handle both client and server enum types + if hasattr(duplicate_handling, "value"): + # server enum type + duplicate_handling = ClientDuplicateFileHandling(duplicate_handling.value) + file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name) + else: + file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name) + + # wait for the file to be processed + start_time = time.time() + while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error": + if time.time() - start_time > max_wait: + raise TimeoutError(f"File processing timed out after {max_wait} seconds") + time.sleep(1) + file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id) + print("Waiting for file processing to complete...", file_metadata.processing_status) + + if file_metadata.processing_status == "error": + raise RuntimeError(f"File processing failed: {file_metadata.error_message}") + + return file_metadata diff --git a/tests/test_managers.py b/tests/test_managers.py index 02956d8f..adeadd41 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5698,6 +5698,59 @@ async def test_get_set_blocks_for_identities(server: SyncServer, default_block, # ====================================================================================================================== +@pytest.mark.asyncio +async def test_get_existing_source_names(server: SyncServer, default_user, event_loop): + """Test the fast batch check 
for existing source names.""" + # Create some test sources + source1 = PydanticSource( + name="test_source_1", + embedding_config=EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-ada-002", + embedding_dim=1536, + embedding_chunk_size=300, + ), + ) + source2 = PydanticSource( + name="test_source_2", + embedding_config=EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-ada-002", + embedding_dim=1536, + embedding_chunk_size=300, + ), + ) + + # Create the sources + created_source1 = await server.source_manager.create_source(source1, default_user) + created_source2 = await server.source_manager.create_source(source2, default_user) + + # Test batch check - mix of existing and non-existing names + names_to_check = ["test_source_1", "test_source_2", "non_existent_source", "another_non_existent"] + existing_names = await server.source_manager.get_existing_source_names(names_to_check, default_user) + + # Verify results + assert len(existing_names) == 2 + assert "test_source_1" in existing_names + assert "test_source_2" in existing_names + assert "non_existent_source" not in existing_names + assert "another_non_existent" not in existing_names + + # Test with empty list + empty_result = await server.source_manager.get_existing_source_names([], default_user) + assert len(empty_result) == 0 + + # Test with all non-existing names + non_existing_result = await server.source_manager.get_existing_source_names(["fake1", "fake2"], default_user) + assert len(non_existing_result) == 0 + + # Cleanup + await server.source_manager.delete_source(created_source1.id, default_user) + await server.source_manager.delete_source(created_source2.id, default_user) + + @pytest.mark.asyncio async def test_create_source(server: SyncServer, default_user, event_loop): """Test creating a new source.""" diff --git 
a/tests/test_sdk_client.py b/tests/test_sdk_client.py index b06f6934..56224581 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -17,6 +17,8 @@ from letta_client.core import ApiError from letta_client.types import AgentState, ToolReturnMessage from pydantic import BaseModel, Field +from tests.helpers.utils import upload_file_and_wait + # Constants SERVER_PORT = 8283 @@ -1869,3 +1871,132 @@ def test_agent_serialization_v2( if len(original_user_msgs) > 0 and len(imported_user_msgs) > 0: assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved" assert "Test message" in imported_user_msgs[0].content, "Test message content not found" + + +def test_export_import_agent_with_files(client: LettaSDKClient): + """Test exporting and importing an agent with files attached.""" + + # Clean up: delete every existing source (no name filter) so leftovers from previous runs cannot conflict + existing_sources = client.sources.list() + for existing_source in existing_sources: + client.sources.delete(source_id=existing_source.id) + + # Create a source and upload test files + source = client.sources.create(name="test_export_source", embedding="openai/text-embedding-3-small") + + # Upload test files to the source + test_files = ["tests/data/test.txt", "tests/data/test.md"] + + for file_path in test_files: + upload_file_and_wait(client, source.id, file_path) + + # Verify files were uploaded successfully + files_in_source = client.sources.files.list(source_id=source.id, limit=10) + assert len(files_in_source) == len(test_files), f"Expected {len(test_files)} files, got {len(files_in_source)}" + + # Create a simple agent with the source attached + temp_agent = client.agents.create( + memory_blocks=[ + CreateBlock(label="human", value="username: sarah"), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + source_ids=[source.id], # Attach the source with files + ) + + # Verify the agent has the source and file blocks + agent_state
= client.agents.retrieve(agent_id=temp_agent.id) + assert len(agent_state.sources) == 1, "Agent should have one source attached" + assert agent_state.sources[0].id == source.id, "Agent should have the correct source attached" + + # Verify file blocks are present + file_blocks = agent_state.memory.file_blocks + assert len(file_blocks) == len(test_files), f"Expected {len(test_files)} file blocks, got {len(file_blocks)}" + + # Export the agent + serialized_agent = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False) + + # Convert to JSON bytes for import + json_str = json.dumps(serialized_agent) + file_obj = io.BytesIO(json_str.encode("utf-8")) + + # Import the agent + import_result = client.agents.import_file(file=file_obj, append_copy_suffix=True, override_existing_tools=True) + + # Verify import was successful + assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent" + imported_agent_id = import_result.agent_ids[0] + imported_agent = client.agents.retrieve(agent_id=imported_agent_id) + + # Verify the source is attached to the imported agent + assert len(imported_agent.sources) == 1, "Imported agent should have one source attached" + imported_source = imported_agent.sources[0] + + # Check that imported source has the same files + imported_files = client.sources.files.list(source_id=imported_source.id, limit=10) + assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files" + + # Verify file blocks are preserved in imported agent + imported_file_blocks = imported_agent.memory.file_blocks + assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks" + + # Verify file block content + for file_block in imported_file_blocks: + assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content" + assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing 
header" + + # Test that files can be opened on the imported agent + if len(imported_files) > 0: + test_file = imported_files[0] + client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id) + + # Clean up + client.agents.delete(agent_id=temp_agent.id) + client.agents.delete(agent_id=imported_agent_id) + client.sources.delete(source_id=source.id) + + +def test_import_agent_with_files_from_disk(client: LettaSDKClient): + """Test importing an agent with files and sources from a pre-exported agent file on disk.""" + # Only the number of entries is used below: the imported source should contain this many files + test_files = ["tests/data/test.txt", "tests/data/test.md"] + + # Path to the pre-exported agent file fixture + file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent_with_files_and_sources.af") + + # Now import from the file + with open(file_path, "rb") as f: + import_result = client.agents.import_file( + file=f, append_copy_suffix=True, override_existing_tools=True # Use suffix to avoid name conflict + ) + + # Verify import was successful + assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent" + imported_agent_id = import_result.agent_ids[0] + imported_agent = client.agents.retrieve(agent_id=imported_agent_id) + + # Verify the source is attached to the imported agent + assert len(imported_agent.sources) == 1, "Imported agent should have one source attached" + imported_source = imported_agent.sources[0] + + # Check that imported source has the same files + imported_files = client.sources.files.list(source_id=imported_source.id, limit=10) + assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files" + + # Verify file blocks are preserved in imported agent + imported_file_blocks = imported_agent.memory.file_blocks + assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks" + + # Verify file block content + for file_block in imported_file_blocks: + assert file_block.value is not None and len(file_block.value)
> 0, "Imported file block should have content" + assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header" + + # Test that files can be opened on the imported agent + if len(imported_files) > 0: + test_file = imported_files[0] + client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id) + + # Clean up agents and sources + client.agents.delete(agent_id=imported_agent_id) + client.sources.delete(source_id=imported_source.id) diff --git a/tests/test_sources.py b/tests/test_sources.py index 905e1297..cb0d2d1e 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -4,11 +4,10 @@ import tempfile import threading import time from datetime import datetime, timedelta -from typing import Optional import pytest from dotenv import load_dotenv -from letta_client import CreateBlock, DuplicateFileHandling +from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient from letta_client import LettaRequest from letta_client import MessageCreate as ClientMessageCreate @@ -19,6 +18,7 @@ from letta.schemas.enums import FileProcessingStatus, ToolType from letta.schemas.message import MessageCreate from letta.schemas.user import User from letta.settings import settings +from tests.helpers.utils import upload_file_and_wait from tests.utils import wait_for_server # Constants @@ -72,36 +72,6 @@ def client() -> LettaSDKClient: yield client -def upload_file_and_wait( - client: LettaSDKClient, - source_id: str, - file_path: str, - name: Optional[str] = None, - max_wait: int = 60, - duplicate_handling: DuplicateFileHandling = None, -): - """Helper function to upload a file and wait for processing to complete""" - with open(file_path, "rb") as f: - if duplicate_handling: - file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name) - else: - file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name) - - # Wait 
for the file to be processed - start_time = time.time() - while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error": - if time.time() - start_time > max_wait: - pytest.fail(f"File processing timed out after {max_wait} seconds") - time.sleep(1) - file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id) - print("Waiting for file processing to complete...", file_metadata.processing_status) - - if file_metadata.processing_status == "error": - pytest.fail(f"File processing failed: {file_metadata.error_message}") - - return file_metadata - - @pytest.fixture def agent_state(disable_pinecone, client: LettaSDKClient): open_file_tool = client.tools.list(name="open_files")[0]