fix: Fix bugs with exporting/importing agents with files (#4089)
This commit is contained in:
@@ -129,7 +129,7 @@ class AgentSchema(CreateAgent):
|
||||
memory_blocks=[], # TODO: Convert from agent_state.memory if needed
|
||||
tools=[],
|
||||
tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [],
|
||||
source_ids=[], # [source.id for source in agent_state.sources] if agent_state.sources else [],
|
||||
source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [],
|
||||
block_ids=[block.id for block in agent_state.memory.blocks],
|
||||
tool_rules=agent_state.tool_rules,
|
||||
tags=agent_state.tags,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -519,8 +520,20 @@ class AgentSerializationManager:
|
||||
if schema.sources:
|
||||
# convert source schemas to pydantic sources
|
||||
pydantic_sources = []
|
||||
|
||||
# First, do a fast batch check for existing source names to avoid conflicts
|
||||
source_names_to_check = [s.name for s in schema.sources]
|
||||
existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, actor)
|
||||
|
||||
for source_schema in schema.sources:
|
||||
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
|
||||
|
||||
# Check if source name already exists, if so add unique suffix
|
||||
original_name = source_data["name"]
|
||||
if original_name in existing_source_names:
|
||||
unique_suffix = uuid.uuid4().hex[:8]
|
||||
source_data["name"] = f"{original_name}_{unique_suffix}"
|
||||
|
||||
pydantic_sources.append(Source(**source_data))
|
||||
|
||||
# bulk upsert all sources at once
|
||||
@@ -529,13 +542,15 @@ class AgentSerializationManager:
|
||||
# map file ids to database ids
|
||||
# note: sources are matched by name during upsert, so we need to match by name here too
|
||||
created_sources_by_name = {source.name: source for source in created_sources}
|
||||
for source_schema in schema.sources:
|
||||
created_source = created_sources_by_name.get(source_schema.name)
|
||||
for i, source_schema in enumerate(schema.sources):
|
||||
# Use the pydantic source name (which may have been modified for uniqueness)
|
||||
source_name = pydantic_sources[i].name
|
||||
created_source = created_sources_by_name.get(source_name)
|
||||
if created_source:
|
||||
file_to_db_ids[source_schema.id] = created_source.id
|
||||
imported_count += 1
|
||||
else:
|
||||
logger.warning(f"Source {source_schema.name} was not created during bulk upsert")
|
||||
logger.warning(f"Source {source_name} was not created during bulk upsert")
|
||||
|
||||
# 4. Create files (depends on sources)
|
||||
for file_schema in schema.files:
|
||||
@@ -595,6 +610,10 @@ class AgentSerializationManager:
|
||||
if agent_data.get("block_ids"):
|
||||
agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]]
|
||||
|
||||
# Remap source_ids from file IDs to database IDs
|
||||
if agent_data.get("source_ids"):
|
||||
agent_data["source_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["source_ids"]]
|
||||
|
||||
if env_vars:
|
||||
for var in agent_data["tool_exec_environment_variables"]:
|
||||
var["value"] = env_vars.get(var["key"], "")
|
||||
@@ -641,14 +660,16 @@ class AgentSerializationManager:
|
||||
for file_agent_schema in agent_schema.files_agents:
|
||||
file_db_id = file_to_db_ids[file_agent_schema.file_id]
|
||||
|
||||
# Use cached file metadata if available
|
||||
# Use cached file metadata if available (with content)
|
||||
if file_db_id not in file_metadata_cache:
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(
|
||||
file_db_id, actor, include_content=True
|
||||
)
|
||||
file_metadata = file_metadata_cache[file_db_id]
|
||||
files_for_agent.append(file_metadata)
|
||||
|
||||
if file_agent_schema.visible_content:
|
||||
visible_content_map[file_db_id] = file_agent_schema.visible_content
|
||||
visible_content_map[file_metadata.file_name] = file_agent_schema.visible_content
|
||||
|
||||
# Bulk attach files to agent
|
||||
await self.file_agent_manager.attach_files_bulk(
|
||||
|
||||
@@ -143,7 +143,6 @@ class SourceManager:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
|
||||
@@ -397,3 +396,29 @@ class SourceManager:
|
||||
sources_orm = result.scalars().all()
|
||||
|
||||
return [source.to_pydantic() for source in sources_orm]
|
||||
|
||||
@enforce_types
@trace_method
async def get_existing_source_names(self, source_names: List[str], actor: PydanticUser) -> set[str]:
    """
    Report which of the given source names are already taken in the actor's organization.

    The whole batch is resolved with one query; rows flagged as deleted are ignored.

    Args:
        source_names: Candidate source names to look up
        actor: User performing the action; the lookup is scoped to actor.organization_id

    Returns:
        The subset of source_names that already exist (an empty set when source_names is empty)
    """
    # Nothing to look up, so skip the database round trip entirely
    if not source_names:
        return set()

    lookup = select(SourceModel.name).where(
        SourceModel.name.in_(source_names),
        SourceModel.organization_id == actor.organization_id,
        SourceModel.is_deleted == False,
    )

    async with db_registry.async_session() as session:
        rows = await session.execute(lookup)
        taken_names = rows.scalars().all()

    return set(taken_names)
|
||||
|
||||
6
poetry.lock
generated
6
poetry.lock
generated
@@ -3473,13 +3473,13 @@ vcr = ["vcrpy (>=7.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "letta-client"
|
||||
version = "0.1.271"
|
||||
version = "0.1.272"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "letta_client-0.1.271-py3-none-any.whl", hash = "sha256:edbf6323e472202090113147b1c9ed280151d4966999686046d48c50c19c74fc"},
|
||||
{file = "letta_client-0.1.271.tar.gz", hash = "sha256:ae7944e594fe87dd80ce5057c42806e8c24b55e11f8fe6d05420fbc5af9b4180"},
|
||||
{file = "letta_client-0.1.272-py3-none-any.whl", hash = "sha256:ed5afffce9431e9dd1170c642efc68b1b5edadfe1923a467f017588dd371447e"},
|
||||
{file = "letta_client-0.1.272.tar.gz", hash = "sha256:40bb1e802aeabbb9cb6eaa2105eff7e8a704ac0962623e4b27d6320e57029dcc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -278,3 +278,39 @@ async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: st
|
||||
with open(file_path, "rb") as f:
|
||||
uploaded = await client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
|
||||
return uploaded
|
||||
|
||||
|
||||
def upload_file_and_wait(
    client: Letta,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: Optional["DuplicateFileHandling"] = None,
):
    """Upload a file to a source and block until its processing finishes.

    Args:
        client: Letta client used for the upload and for polling the file status
        source_id: ID of the source that receives the file
        file_path: Path of the local file to upload (opened in binary mode)
        name: Optional name forwarded to the upload call
        max_wait: Maximum number of seconds to wait for processing
        duplicate_handling: Optional duplicate-handling policy; a value exposing
            ``.value`` is converted to the client enum before the upload

    Returns:
        The file metadata once its processing status is "completed"

    Raises:
        TimeoutError: If processing has not finished within max_wait seconds
        RuntimeError: If processing ends in the "error" status
    """
    from letta_client import DuplicateFileHandling as ClientDuplicateFileHandling

    with open(file_path, "rb") as f:
        if duplicate_handling:
            # handle both client and server enum types
            if hasattr(duplicate_handling, "value"):
                # server enum type
                duplicate_handling = ClientDuplicateFileHandling(duplicate_handling.value)
            file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name)
        else:
            file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name)

    # wait for the file to be processed; a monotonic clock keeps the timeout
    # correct even if the system wall clock is adjusted while we poll
    start_time = time.monotonic()
    while file_metadata.processing_status not in ("completed", "error"):
        if time.monotonic() - start_time > max_wait:
            raise TimeoutError(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
        print("Waiting for file processing to complete...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        raise RuntimeError(f"File processing failed: {file_metadata.error_message}")

    return file_metadata
|
||||
|
||||
720
tests/test_agent_files/test_agent_with_files_and_sources.af
Normal file
720
tests/test_agent_files/test_agent_with_files_and_sources.af
Normal file
File diff suppressed because one or more lines are too long
@@ -5698,6 +5698,59 @@ async def test_get_set_blocks_for_identities(server: SyncServer, default_block,
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_existing_source_names(server: SyncServer, default_user, event_loop):
    """Test the fast batch check for existing source names."""
    created_sources = []
    try:
        # Create some test sources (identical embedding config, different names)
        for source_name in ("test_source_1", "test_source_2"):
            source = PydanticSource(
                name=source_name,
                embedding_config=EmbeddingConfig(
                    embedding_endpoint_type="openai",
                    embedding_endpoint="https://api.openai.com/v1",
                    embedding_model="text-embedding-ada-002",
                    embedding_dim=1536,
                    embedding_chunk_size=300,
                ),
            )
            created_sources.append(await server.source_manager.create_source(source, default_user))

        # Test batch check - mix of existing and non-existing names
        names_to_check = ["test_source_1", "test_source_2", "non_existent_source", "another_non_existent"]
        existing_names = await server.source_manager.get_existing_source_names(names_to_check, default_user)

        # Verify results
        assert len(existing_names) == 2
        assert "test_source_1" in existing_names
        assert "test_source_2" in existing_names
        assert "non_existent_source" not in existing_names
        assert "another_non_existent" not in existing_names

        # Test with empty list
        empty_result = await server.source_manager.get_existing_source_names([], default_user)
        assert len(empty_result) == 0

        # Test with all non-existing names
        non_existing_result = await server.source_manager.get_existing_source_names(["fake1", "fake2"], default_user)
        assert len(non_existing_result) == 0
    finally:
        # Cleanup runs even when an assertion fails, so a failed run does not
        # leave sources behind for later tests
        for created_source in created_sources:
            await server.source_manager.delete_source(created_source.id, default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_source(server: SyncServer, default_user, event_loop):
|
||||
"""Test creating a new source."""
|
||||
|
||||
@@ -17,6 +17,8 @@ from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from tests.helpers.utils import upload_file_and_wait
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
@@ -1869,3 +1871,132 @@ def test_agent_serialization_v2(
|
||||
if len(original_user_msgs) > 0 and len(imported_user_msgs) > 0:
|
||||
assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved"
|
||||
assert "Test message" in imported_user_msgs[0].content, "Test message content not found"
|
||||
|
||||
|
||||
def test_export_import_agent_with_files(client: LettaSDKClient):
    """Test exporting and importing an agent with files attached."""

    # Start from a clean slate: remove every source visible to this client
    # (including leftovers from previous runs)
    existing_sources = client.sources.list()
    for existing_source in existing_sources:
        client.sources.delete(source_id=existing_source.id)

    # Create a source and upload test files
    source = client.sources.create(name="test_export_source", embedding="openai/text-embedding-3-small")

    # Upload test files to the source
    test_files = ["tests/data/test.txt", "tests/data/test.md"]

    for file_path in test_files:
        upload_file_and_wait(client, source.id, file_path)

    # Verify files were uploaded successfully
    files_in_source = client.sources.files.list(source_id=source.id, limit=10)
    assert len(files_in_source) == len(test_files), f"Expected {len(test_files)} files, got {len(files_in_source)}"

    # Create a simple agent with the source attached
    temp_agent = client.agents.create(
        memory_blocks=[
            CreateBlock(label="human", value="username: sarah"),
        ],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
        source_ids=[source.id],  # Attach the source with files
    )

    # Verify the agent has the source and file blocks
    agent_state = client.agents.retrieve(agent_id=temp_agent.id)
    assert len(agent_state.sources) == 1, "Agent should have one source attached"
    assert agent_state.sources[0].id == source.id, "Agent should have the correct source attached"

    # Verify file blocks are present
    file_blocks = agent_state.memory.file_blocks
    assert len(file_blocks) == len(test_files), f"Expected {len(test_files)} file blocks, got {len(file_blocks)}"

    # Export the agent
    serialized_agent = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)

    # Convert to JSON bytes for import
    json_str = json.dumps(serialized_agent)
    file_obj = io.BytesIO(json_str.encode("utf-8"))

    # Import the agent
    import_result = client.agents.import_file(file=file_obj, append_copy_suffix=True, override_existing_tools=True)

    # Verify import was successful
    assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
    imported_agent_id = import_result.agent_ids[0]
    imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

    # Verify the source is attached to the imported agent
    assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
    imported_source = imported_agent.sources[0]

    # Check that imported source has the same files
    imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
    assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"

    # Verify file blocks are preserved in imported agent
    imported_file_blocks = imported_agent.memory.file_blocks
    assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"

    # Verify file block content
    for file_block in imported_file_blocks:
        assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
        assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"

    # Test that files can be opened on the imported agent
    if len(imported_files) > 0:
        test_file = imported_files[0]
        client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id)

    # Clean up agents first, then the sources they reference
    client.agents.delete(agent_id=temp_agent.id)
    client.agents.delete(agent_id=imported_agent_id)
    client.sources.delete(source_id=source.id)
    # The import may have produced its own source; remove it too so it does not
    # leak into later tests (skip if it is the source already deleted above)
    if imported_source.id != source.id:
        client.sources.delete(source_id=imported_source.id)
|
||||
|
||||
|
||||
def test_import_agent_with_files_from_disk(client: LettaSDKClient):
    """Test importing an agent file with files and sources that is stored on disk."""
    # Files the stored agent file is expected to carry; only the count is checked below
    test_files = ["tests/data/test.txt", "tests/data/test.md"]

    # Location of the stored agent file, next to this test module
    tests_dir = os.path.dirname(__file__)
    file_path = os.path.join(tests_dir, "test_agent_files", "test_agent_with_files_and_sources.af")

    # Import the agent from the file; the copy suffix avoids a name conflict
    with open(file_path, "rb") as f:
        import_result = client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=True)

    # Exactly one agent must come out of the import
    assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
    imported_agent_id = import_result.agent_ids[0]
    imported_agent = client.agents.retrieve(agent_id=imported_agent_id)

    # The imported agent must have a single source attached
    assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
    imported_source = imported_agent.sources[0]

    # That source must hold the expected number of files
    imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
    assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"

    # One file block per file must be present on the imported agent
    imported_file_blocks = imported_agent.memory.file_blocks
    assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"

    # Every file block must carry content with the viewing header
    for file_block in imported_file_blocks:
        block_value = file_block.value
        assert block_value is not None and len(block_value) > 0, "Imported file block should have content"
        assert "[Viewing file start" in block_value, "Imported file block should show file viewing header"

    # Opening one of the imported files on the imported agent must work
    if len(imported_files) > 0:
        first_file = imported_files[0]
        client.agents.files.open(agent_id=imported_agent_id, file_id=first_file.id)

    # Clean up agents and sources
    client.agents.delete(agent_id=imported_agent_id)
    client.sources.delete(source_id=imported_source.id)
|
||||
|
||||
@@ -4,11 +4,10 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock, DuplicateFileHandling
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest
|
||||
from letta_client import MessageCreate as ClientMessageCreate
|
||||
@@ -19,6 +18,7 @@ from letta.schemas.enums import FileProcessingStatus, ToolType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from tests.helpers.utils import upload_file_and_wait
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
@@ -72,36 +72,6 @@ def client() -> LettaSDKClient:
|
||||
yield client
|
||||
|
||||
|
||||
def upload_file_and_wait(
    client: LettaSDKClient,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: DuplicateFileHandling = None,
):
    """Upload a file to a source, then poll until processing completes or errors.

    Fails the running test via pytest.fail when processing takes longer than
    max_wait seconds or ends in the "error" status; otherwise returns the
    final file metadata.
    """
    # Only forward duplicate_handling when the caller supplied a truthy value
    upload_kwargs = {"source_id": source_id, "name": name}
    if duplicate_handling:
        upload_kwargs["duplicate_handling"] = duplicate_handling

    with open(file_path, "rb") as f:
        file_metadata = client.sources.files.upload(file=f, **upload_kwargs)

    # Poll once per second until the file reaches a terminal status
    terminal_statuses = ("completed", "error")
    started_at = time.time()
    while file_metadata.processing_status not in terminal_statuses:
        elapsed = time.time() - started_at
        if elapsed > max_wait:
            pytest.fail(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
        print("Waiting for file processing to complete...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        pytest.fail(f"File processing failed: {file_metadata.error_message}")

    return file_metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_state(disable_pinecone, client: LettaSDKClient):
|
||||
open_file_tool = client.tools.list(name="open_files")[0]
|
||||
|
||||
Reference in New Issue
Block a user