From 31a96179659061cd2e056e1de8762ff65380467e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 10 Jul 2025 13:16:01 -0700 Subject: [PATCH] feat: Add duplication handling behavior (#3273) --- letta/schemas/enums.py | 8 ++ letta/server/rest_api/routers/v1/sources.py | 28 ++++- letta/services/file_manager.py | 54 +++++++- tests/test_managers.py | 131 ++++++++++++++++++++ 4 files changed, 214 insertions(+), 7 deletions(-) diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index f4cd35f8..cbe922ca 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -139,3 +139,11 @@ class MCPServerType(str, Enum): SSE = "sse" STDIO = "stdio" STREAMABLE_HTTP = "streamable_http" + + +class DuplicateFileHandling(str, Enum): + """How to handle duplicate filenames when uploading files""" + + SKIP = "skip" # skip files with duplicate names + ERROR = "error" # error when duplicate names are encountered + SUFFIX = "suffix" # add numeric suffix to make names unique (default behavior) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 9daecbdc..e7ab5370 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -19,7 +19,7 @@ from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import FileProcessingStatus +from letta.schemas.enums import DuplicateFileHandling, FileProcessingStatus from letta.schemas.file import FileMetadata from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate @@ -208,6 +208,7 @@ async def delete_source( async def upload_file_to_source( file: UploadFile, source_id: str, + duplicate_handling: DuplicateFileHandling = Query(DuplicateFileHandling.SUFFIX, description="How to handle duplicate filenames"), server: "SyncServer" = 
Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -264,8 +265,31 @@ async def upload_file_to_source( content = await file.read() - # Store original filename and generate unique filename + # Store original filename and handle duplicate logic original_filename = sanitize_filename(file.filename) # Basic sanitization only + + # Check if duplicate exists + existing_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=source_id, actor=actor + ) + + if existing_file: + # Duplicate found, handle based on strategy + if duplicate_handling == DuplicateFileHandling.ERROR: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=f"File '{original_filename}' already exists in source '{source.name}'" + ) + elif duplicate_handling == DuplicateFileHandling.SKIP: + # Return existing file metadata with custom header to indicate it was skipped + from fastapi import Response + + response = Response( + content=existing_file.model_dump_json(), media_type="application/json", headers={"X-Upload-Result": "skipped"} + ) + return response + # For SUFFIX, continue to generate unique filename + + # Generate unique filename (adds suffix if needed) unique_filename = await server.file_manager.generate_unique_filename( original_filename=original_filename, source=source, organization_id=actor.organization_id ) diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index f4a84fb3..530fa3e1 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -22,6 +22,15 @@ from letta.server.db import db_registry from letta.utils import enforce_types +class DuplicateFileError(Exception): + """Raised when a duplicate file is encountered and error handling is specified""" + + def __init__(self, filename: str, source_name: str): + self.filename = filename + self.source_name = source_name + super().__init__(f"File '{filename}' already exists in source 
'{source_name}'") + + class FileManager: """Manager class to handle business logic related to files.""" @@ -237,16 +246,16 @@ class FileManager: @trace_method async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str: """ - Generate a unique filename by checking for duplicates and adding a numeric suffix if needed. - Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt). + Generate a unique filename by adding a numeric suffix if duplicates exist. + Always returns a unique filename - does not handle duplicate policies. Parameters: original_filename (str): The original filename as uploaded. - source_id (str): Source ID to check for duplicates within. + source (PydanticSource): Source to check for duplicates within. organization_id (str): Organization ID to check for duplicates within. Returns: - str: A unique filename with numeric suffix if needed. + str: A unique filename with source.name prefix and numeric suffix if needed. """ base, ext = os.path.splitext(original_filename) @@ -271,9 +280,44 @@ class FileManager: # No duplicates, return original filename with source.name return f"{source.name}/{original_filename}" else: - # Add numeric suffix + # Add numeric suffix to make unique return f"{source.name}/{base}_({count}){ext}" + @enforce_types + @trace_method + async def get_file_by_original_name_and_source( + self, original_filename: str, source_id: str, actor: PydanticUser + ) -> Optional[PydanticFileMetadata]: + """ + Get a file by its original filename and source ID. + + Parameters: + original_filename (str): The original filename to search for. + source_id (str): The source ID to search within. + actor (PydanticUser): The actor performing the request. + + Returns: + Optional[PydanticFileMetadata]: The file metadata if found, None otherwise. 
+ """ + async with db_registry.async_session() as session: + query = ( + select(FileMetadataModel) + .where( + FileMetadataModel.original_file_name == original_filename, + FileMetadataModel.source_id == source_id, + FileMetadataModel.organization_id == actor.organization_id, + FileMetadataModel.is_deleted == False, + ) + .limit(1) + ) + + result = await session.execute(query) + file_orm = result.scalar_one_or_none() + + if file_orm: + return await file_orm.to_pydantic_async() + return None + @enforce_types @trace_method async def get_organization_sources_metadata(self, actor: PydanticUser) -> OrganizationSourcesStats: diff --git a/tests/test_managers.py b/tests/test_managers.py index be30c032..cb43063e 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5117,6 +5117,137 @@ async def test_delete_cascades_to_content(server, default_user, default_source, assert await _count_file_content_rows(async_session, created.id) == 0 +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_found(server: SyncServer, default_user, default_source): + """Test retrieving a file by original filename and source when it exists.""" + original_filename = "test_original_file.txt" + file_metadata = PydanticFileMetadata( + file_name="some_generated_name.txt", + original_file_name=original_filename, + file_path="/path/to/test_file.txt", + file_type="text/plain", + file_size=1024, + source_id=default_source.id, + ) + created_file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user) + + # Retrieve the file by original name and source + retrieved_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + + # Assertions to verify the retrieved file matches the created one + assert retrieved_file is not None + assert retrieved_file.id == created_file.id + assert retrieved_file.original_file_name == original_filename + assert 
retrieved_file.source_id == default_source.id + + +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_not_found(server: SyncServer, default_user, default_source): + """Test retrieving a file by original filename and source when it doesn't exist.""" + non_existent_filename = "does_not_exist.txt" + + # Try to retrieve a non-existent file + retrieved_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=non_existent_filename, source_id=default_source.id, actor=default_user + ) + + # Should return None for non-existent file + assert retrieved_file is None + + +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_different_sources(server: SyncServer, default_user, default_source): + """Test that files with same original name in different sources are handled correctly.""" + from letta.schemas.source import Source as PydanticSource + + # Create a second source + second_source_pydantic = PydanticSource( + name="second_test_source", + description="This is a test source.", + metadata={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + second_source = await server.source_manager.create_source(source=second_source_pydantic, actor=default_user) + + original_filename = "shared_filename.txt" + + # Create file in first source + file_metadata_1 = PydanticFileMetadata( + file_name="file_in_source_1.txt", + original_file_name=original_filename, + file_path="/path/to/file1.txt", + file_type="text/plain", + file_size=1024, + source_id=default_source.id, + ) + created_file_1 = await server.file_manager.create_file(file_metadata=file_metadata_1, actor=default_user) + + # Create file with same original name in second source + file_metadata_2 = PydanticFileMetadata( + file_name="file_in_source_2.txt", + original_file_name=original_filename, + file_path="/path/to/file2.txt", + file_type="text/plain", + file_size=2048, + source_id=second_source.id, + ) + created_file_2 = await 
server.file_manager.create_file(file_metadata=file_metadata_2, actor=default_user) + + # Retrieve file from first source + retrieved_file_1 = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + + # Retrieve file from second source + retrieved_file_2 = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=second_source.id, actor=default_user + ) + + # Should retrieve different files + assert retrieved_file_1 is not None + assert retrieved_file_2 is not None + assert retrieved_file_1.id == created_file_1.id + assert retrieved_file_2.id == created_file_2.id + assert retrieved_file_1.id != retrieved_file_2.id + assert retrieved_file_1.source_id == default_source.id + assert retrieved_file_2.source_id == second_source.id + + +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_ignores_deleted(server: SyncServer, default_user, default_source): + """Test that deleted files are ignored when searching by original name and source.""" + original_filename = "to_be_deleted.txt" + file_metadata = PydanticFileMetadata( + file_name="deletable_file.txt", + original_file_name=original_filename, + file_path="/path/to/deletable.txt", + file_type="text/plain", + file_size=512, + source_id=default_source.id, + ) + created_file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user) + + # Verify file can be found before deletion + retrieved_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + assert retrieved_file is not None + assert retrieved_file.id == created_file.id + + # Delete the file + await server.file_manager.delete_file(created_file.id, actor=default_user) + + # Try to retrieve the deleted file + retrieved_file_after_delete = await 
server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + + # Should return None for deleted file + assert retrieved_file_after_delete is None + + @pytest.mark.asyncio async def test_list_files(server: SyncServer, default_user, default_source): """Test listing files with pagination."""