fix: Fix bugs with exporting/importing agents with files (#4089)

This commit is contained in:
Matthew Zhou
2025-08-21 16:23:37 -07:00
committed by GitHub
parent a4cd4a9487
commit a2f4ca5f89
8 changed files with 279 additions and 43 deletions

View File

@@ -129,7 +129,7 @@ class AgentSchema(CreateAgent):
memory_blocks=[], # TODO: Convert from agent_state.memory if needed
tools=[],
tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [],
source_ids=[], # [source.id for source in agent_state.sources] if agent_state.sources else [],
source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [],
block_ids=[block.id for block in agent_state.memory.blocks],
tool_rules=agent_state.tool_rules,
tags=agent_state.tags,

View File

@@ -1,3 +1,4 @@
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@@ -519,8 +520,20 @@ class AgentSerializationManager:
if schema.sources:
# convert source schemas to pydantic sources
pydantic_sources = []
# First, do a fast batch check for existing source names to avoid conflicts
source_names_to_check = [s.name for s in schema.sources]
existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, actor)
for source_schema in schema.sources:
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
# Check if source name already exists, if so add unique suffix
original_name = source_data["name"]
if original_name in existing_source_names:
unique_suffix = uuid.uuid4().hex[:8]
source_data["name"] = f"{original_name}_{unique_suffix}"
pydantic_sources.append(Source(**source_data))
# bulk upsert all sources at once
@@ -529,13 +542,15 @@ class AgentSerializationManager:
# map file ids to database ids
# note: sources are matched by name during upsert, so we need to match by name here too
created_sources_by_name = {source.name: source for source in created_sources}
for source_schema in schema.sources:
created_source = created_sources_by_name.get(source_schema.name)
for i, source_schema in enumerate(schema.sources):
# Use the pydantic source name (which may have been modified for uniqueness)
source_name = pydantic_sources[i].name
created_source = created_sources_by_name.get(source_name)
if created_source:
file_to_db_ids[source_schema.id] = created_source.id
imported_count += 1
else:
logger.warning(f"Source {source_schema.name} was not created during bulk upsert")
logger.warning(f"Source {source_name} was not created during bulk upsert")
# 4. Create files (depends on sources)
for file_schema in schema.files:
@@ -595,6 +610,10 @@ class AgentSerializationManager:
if agent_data.get("block_ids"):
agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]]
# Remap source_ids from file IDs to database IDs
if agent_data.get("source_ids"):
agent_data["source_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["source_ids"]]
if env_vars:
for var in agent_data["tool_exec_environment_variables"]:
var["value"] = env_vars.get(var["key"], "")
@@ -641,14 +660,16 @@ class AgentSerializationManager:
for file_agent_schema in agent_schema.files_agents:
file_db_id = file_to_db_ids[file_agent_schema.file_id]
# Use cached file metadata if available
# Use cached file metadata if available (with content)
if file_db_id not in file_metadata_cache:
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(
file_db_id, actor, include_content=True
)
file_metadata = file_metadata_cache[file_db_id]
files_for_agent.append(file_metadata)
if file_agent_schema.visible_content:
visible_content_map[file_db_id] = file_agent_schema.visible_content
visible_content_map[file_metadata.file_name] = file_agent_schema.visible_content
# Bulk attach files to agent
await self.file_agent_manager.attach_files_bulk(

View File

@@ -143,7 +143,6 @@ class SourceManager:
update_dict[col.name] = excluded[col.name]
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
await session.execute(upsert_stmt)
await session.commit()
@@ -397,3 +396,29 @@ class SourceManager:
sources_orm = result.scalars().all()
return [source.to_pydantic() for source in sources_orm]
@enforce_types
@trace_method
async def get_existing_source_names(self, source_names: List[str], actor: PydanticUser) -> set[str]:
    """
    Batch-check which of the given source names already exist for the actor's organization.

    Args:
        source_names: Candidate source names to look up
        actor: User performing the action (scopes the lookup to their organization)

    Returns:
        The subset of `source_names` that already exist, as a set
    """
    # Nothing to check — avoid opening a session for an empty query.
    if not source_names:
        return set()
    async with db_registry.async_session() as session:
        stmt = select(SourceModel.name).where(
            SourceModel.name.in_(source_names),
            SourceModel.organization_id == actor.organization_id,
            SourceModel.is_deleted == False,
        )
        rows = await session.execute(stmt)
        return {name for name in rows.scalars()}

6
poetry.lock generated
View File

@@ -3473,13 +3473,13 @@ vcr = ["vcrpy (>=7.0.0)"]
[[package]]
name = "letta-client"
version = "0.1.271"
version = "0.1.272"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "letta_client-0.1.271-py3-none-any.whl", hash = "sha256:edbf6323e472202090113147b1c9ed280151d4966999686046d48c50c19c74fc"},
{file = "letta_client-0.1.271.tar.gz", hash = "sha256:ae7944e594fe87dd80ce5057c42806e8c24b55e11f8fe6d05420fbc5af9b4180"},
{file = "letta_client-0.1.272-py3-none-any.whl", hash = "sha256:ed5afffce9431e9dd1170c642efc68b1b5edadfe1923a467f017588dd371447e"},
{file = "letta_client-0.1.272.tar.gz", hash = "sha256:40bb1e802aeabbb9cb6eaa2105eff7e8a704ac0962623e4b27d6320e57029dcc"},
]
[package.dependencies]

View File

@@ -278,3 +278,39 @@ async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: st
with open(file_path, "rb") as f:
uploaded = await client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
return uploaded
def upload_file_and_wait(
    client: Letta,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: Optional["DuplicateFileHandling"] = None,
):
    """Helper function to upload a file and wait for processing to complete"""
    from letta_client import DuplicateFileHandling as ClientDuplicateFileHandling

    with open(file_path, "rb") as fh:
        upload_kwargs = {"source_id": source_id, "file": fh, "name": name}
        if duplicate_handling:
            # handle both client and server enum types
            if hasattr(duplicate_handling, "value"):
                # server enum type
                duplicate_handling = ClientDuplicateFileHandling(duplicate_handling.value)
            upload_kwargs["duplicate_handling"] = duplicate_handling
        file_metadata = client.sources.files.upload(**upload_kwargs)

    # wait for the file to be processed
    deadline = time.time() + max_wait
    while file_metadata.processing_status not in ("completed", "error"):
        if time.time() > deadline:
            raise TimeoutError(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
        print("Waiting for file processing to complete...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        raise RuntimeError(f"File processing failed: {file_metadata.error_message}")

    return file_metadata

View File

@@ -5698,6 +5698,59 @@ async def test_get_set_blocks_for_identities(server: SyncServer, default_block,
# ======================================================================================================================
@pytest.mark.asyncio
async def test_get_existing_source_names(server: SyncServer, default_user, event_loop):
    """Test the fast batch check for existing source names.

    Fix: cleanup is now in a ``finally`` block so the fixture sources are
    deleted even when an assertion fails, preventing state from leaking
    into other tests. The duplicated EmbeddingConfig literal is also
    factored into a small helper.
    """

    def _embedding_config() -> EmbeddingConfig:
        # Shared embedding config for both fixture sources.
        return EmbeddingConfig(
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_model="text-embedding-ada-002",
            embedding_dim=1536,
            embedding_chunk_size=300,
        )

    # Create some test sources
    created_source1 = await server.source_manager.create_source(
        PydanticSource(name="test_source_1", embedding_config=_embedding_config()), default_user
    )
    created_source2 = await server.source_manager.create_source(
        PydanticSource(name="test_source_2", embedding_config=_embedding_config()), default_user
    )

    try:
        # Test batch check - mix of existing and non-existing names
        names_to_check = ["test_source_1", "test_source_2", "non_existent_source", "another_non_existent"]
        existing_names = await server.source_manager.get_existing_source_names(names_to_check, default_user)

        # Verify results: exactly the two created names, nothing else
        assert existing_names == {"test_source_1", "test_source_2"}
        assert "non_existent_source" not in existing_names
        assert "another_non_existent" not in existing_names

        # Test with empty list
        empty_result = await server.source_manager.get_existing_source_names([], default_user)
        assert len(empty_result) == 0

        # Test with all non-existing names
        non_existing_result = await server.source_manager.get_existing_source_names(["fake1", "fake2"], default_user)
        assert len(non_existing_result) == 0
    finally:
        # Cleanup runs even if an assertion above failed
        await server.source_manager.delete_source(created_source1.id, default_user)
        await server.source_manager.delete_source(created_source2.id, default_user)
@pytest.mark.asyncio
async def test_create_source(server: SyncServer, default_user, event_loop):
"""Test creating a new source."""

View File

@@ -17,6 +17,8 @@ from letta_client.core import ApiError
from letta_client.types import AgentState, ToolReturnMessage
from pydantic import BaseModel, Field
from tests.helpers.utils import upload_file_and_wait
# Constants
SERVER_PORT = 8283
@@ -1869,3 +1871,132 @@ def test_agent_serialization_v2(
if len(original_user_msgs) > 0 and len(imported_user_msgs) > 0:
assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved"
assert "Test message" in imported_user_msgs[0].content, "Test message content not found"
def test_export_import_agent_with_files(client: LettaSDKClient):
    """Test exporting and importing an agent with files attached.

    Fix: all cleanup (agents + source) now runs in a ``finally`` block so a
    failed assertion mid-test no longer leaks the created agents/source into
    subsequent tests.
    """
    # Clean up any existing source with the same name from previous runs
    existing_sources = client.sources.list()
    for existing_source in existing_sources:
        client.sources.delete(source_id=existing_source.id)

    temp_agent = None
    imported_agent_id = None
    # Create a source and upload test files
    source = client.sources.create(name="test_export_source", embedding="openai/text-embedding-3-small")
    try:
        # Upload test files to the source
        test_files = ["tests/data/test.txt", "tests/data/test.md"]
        for file_path in test_files:
            upload_file_and_wait(client, source.id, file_path)

        # Verify files were uploaded successfully
        files_in_source = client.sources.files.list(source_id=source.id, limit=10)
        assert len(files_in_source) == len(test_files), f"Expected {len(test_files)} files, got {len(files_in_source)}"

        # Create a simple agent with the source attached
        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(label="human", value="username: sarah"),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            source_ids=[source.id],  # Attach the source with files
        )

        # Verify the agent has the source and file blocks
        agent_state = client.agents.retrieve(agent_id=temp_agent.id)
        assert len(agent_state.sources) == 1, "Agent should have one source attached"
        assert agent_state.sources[0].id == source.id, "Agent should have the correct source attached"

        # Verify file blocks are present
        file_blocks = agent_state.memory.file_blocks
        assert len(file_blocks) == len(test_files), f"Expected {len(test_files)} file blocks, got {len(file_blocks)}"

        # Export the agent
        serialized_agent = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)

        # Convert to JSON bytes for import
        json_str = json.dumps(serialized_agent)
        file_obj = io.BytesIO(json_str.encode("utf-8"))

        # Import the agent
        import_result = client.agents.import_file(file=file_obj, append_copy_suffix=True, override_existing_tools=True)

        # Verify import was successful
        assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
        imported_agent_id = import_result.agent_ids[0]
        imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

        # Verify the source is attached to the imported agent
        assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
        imported_source = imported_agent.sources[0]

        # Check that imported source has the same files
        imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
        assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"

        # Verify file blocks are preserved in imported agent
        imported_file_blocks = imported_agent.memory.file_blocks
        assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"

        # Verify file block content
        for file_block in imported_file_blocks:
            assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
            assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"

        # Test that files can be opened on the imported agent
        if len(imported_files) > 0:
            test_file = imported_files[0]
            client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id)
    finally:
        # Clean up — guarded so partially-created resources are still removed on failure
        if temp_agent is not None:
            client.agents.delete(agent_id=temp_agent.id)
        if imported_agent_id is not None:
            client.agents.delete(agent_id=imported_agent_id)
        client.sources.delete(source_id=source.id)
def test_import_agent_with_files_from_disk(client: LettaSDKClient):
    """Test exporting an agent with files to disk and importing it back.

    Fix: cleanup now runs in a ``finally`` block so a failed assertion no
    longer leaks the imported agent/source. The old code also referenced
    ``imported_source`` during cleanup, which was unbound if the first
    assertion failed; cleanup now iterates whatever sources actually exist.
    """
    # The .af fixture was exported with these two files attached
    test_files = ["tests/data/test.txt", "tests/data/test.md"]

    # Path of the previously exported agent file on disk
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent_with_files_and_sources.af")

    # Now import from the file
    with open(file_path, "rb") as f:
        import_result = client.agents.import_file(
            file=f, append_copy_suffix=True, override_existing_tools=True  # Use suffix to avoid name conflict
        )

    # Verify import was successful
    assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
    imported_agent_id = import_result.agent_ids[0]
    imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

    try:
        # Verify the source is attached to the imported agent
        assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
        imported_source = imported_agent.sources[0]

        # Check that imported source has the same files
        imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
        assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"

        # Verify file blocks are preserved in imported agent
        imported_file_blocks = imported_agent.memory.file_blocks
        assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"

        # Verify file block content
        for file_block in imported_file_blocks:
            assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
            assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"

        # Test that files can be opened on the imported agent
        if len(imported_files) > 0:
            test_file = imported_files[0]
            client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id)
    finally:
        # Clean up agents and sources even if an assertion above failed
        client.agents.delete(agent_id=imported_agent_id)
        for src in imported_agent.sources:
            client.sources.delete(source_id=src.id)

View File

@@ -4,11 +4,10 @@ import tempfile
import threading
import time
from datetime import datetime, timedelta
from typing import Optional
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock, DuplicateFileHandling
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client import LettaRequest
from letta_client import MessageCreate as ClientMessageCreate
@@ -19,6 +18,7 @@ from letta.schemas.enums import FileProcessingStatus, ToolType
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
from letta.settings import settings
from tests.helpers.utils import upload_file_and_wait
from tests.utils import wait_for_server
# Constants
@@ -72,36 +72,6 @@ def client() -> LettaSDKClient:
yield client
def upload_file_and_wait(
    client: LettaSDKClient,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: DuplicateFileHandling = None,
):
    """Helper function to upload a file and wait for processing to complete"""
    with open(file_path, "rb") as handle:
        upload_kwargs = {"source_id": source_id, "file": handle, "name": name}
        if duplicate_handling:
            upload_kwargs["duplicate_handling"] = duplicate_handling
        file_metadata = client.sources.files.upload(**upload_kwargs)

    # Wait for the file to be processed
    deadline = time.time() + max_wait
    while file_metadata.processing_status not in ("completed", "error"):
        if time.time() > deadline:
            pytest.fail(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
        print("Waiting for file processing to complete...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        pytest.fail(f"File processing failed: {file_metadata.error_message}")

    return file_metadata
@pytest.fixture
def agent_state(disable_pinecone, client: LettaSDKClient):
open_file_tool = client.tools.list(name="open_files")[0]