From 804ec12ee289cb06a7702286f720438649b48012 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 1 Jul 2025 13:48:38 -0700 Subject: [PATCH] feat: Only add suffix on duplication (#3120) --- ..._add_unique_constraint_to_source_names_.py | 68 ++++++++++++++++++ letta/constants.py | 2 + letta/llm_api/openai_client.py | 2 +- letta/orm/file.py | 2 + letta/orm/files_agents.py | 2 +- letta/orm/source.py | 3 +- letta/prompts/system/memgpt_v2_chat.txt | 4 +- letta/schemas/agent.py | 30 ++++---- letta/schemas/block.py | 3 - letta/schemas/file.py | 1 + letta/server/rest_api/routers/v1/sources.py | 40 ++++++----- letta/server/server.py | 10 +-- letta/services/agent_manager.py | 2 +- letta/services/file_manager.py | 45 +++++++++++- letta/utils.py | 26 ++++--- tests/test_sources.py | 69 +++++++++++++++---- tests/test_utils.py | 28 +++++--- 17 files changed, 259 insertions(+), 78 deletions(-) create mode 100644 alembic/versions/46699adc71a7_add_unique_constraint_to_source_names_.py diff --git a/alembic/versions/46699adc71a7_add_unique_constraint_to_source_names_.py b/alembic/versions/46699adc71a7_add_unique_constraint_to_source_names_.py new file mode 100644 index 00000000..d61cbc6f --- /dev/null +++ b/alembic/versions/46699adc71a7_add_unique_constraint_to_source_names_.py @@ -0,0 +1,68 @@ +"""Add unique constraint to source names and also add original file name column + +Revision ID: 46699adc71a7 +Revises: 1af251a42c06 +Create Date: 2025-07-01 13:30:48.279151 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "46699adc71a7" +down_revision: Union[str, None] = "1af251a42c06" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("files", sa.Column("original_file_name", sa.String(), nullable=True)) + + # Handle existing duplicate source names before adding unique constraint + connection = op.get_bind() + + # Find duplicates and rename them by appending a suffix + result = connection.execute( + sa.text( + """ + WITH duplicates AS ( + SELECT name, organization_id, + ROW_NUMBER() OVER (PARTITION BY name, organization_id ORDER BY created_at) as rn, + id + FROM sources + WHERE (name, organization_id) IN ( + SELECT name, organization_id + FROM sources + GROUP BY name, organization_id + HAVING COUNT(*) > 1 + ) + ) + SELECT id, name, rn + FROM duplicates + WHERE rn > 1 + """ + ) + ) + + # Rename duplicates by appending a number suffix + for row in result: + source_id, original_name, duplicate_number = row + new_name = f"{original_name}_{duplicate_number}" + connection.execute( + sa.text("UPDATE sources SET name = :new_name WHERE id = :source_id"), {"new_name": new_name, "source_id": source_id} + ) + + op.create_unique_constraint("uq_source_name_organization", "sources", ["name", "organization_id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("uq_source_name_organization", "sources", type_="unique") + op.drop_column("files", "original_file_name") + # ### end Alembic commands ### diff --git a/letta/constants.py b/letta/constants.py index 8d3b72de..ee2c7798 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -361,3 +361,5 @@ REDIS_DEFAULT_CACHE_PREFIX = "letta_cache" # TODO: This is temporary, eventually use token-based eviction MAX_FILES_OPEN = 5 + +GET_PROVIDERS_TIMEOUT_SECONDS = 10 diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index d7c4bd43..ea17d0da 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -305,7 +305,7 @@ class OpenAIClient(LLMClientBase): return response_stream @trace_method - async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[dict]: + async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]: """Request embeddings given texts and embedding config""" kwargs = self._prepare_client_kwargs_embedding(embedding_config) client = AsyncOpenAI(**kwargs) diff --git a/letta/orm/file.py b/letta/orm/file.py index 2e8e5088..f2c83a6f 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -49,6 +49,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): ) file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.") + original_file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The original name of the file as uploaded.") file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.") file_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The type of the file.") file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.") @@ -99,6 +100,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs): organization_id=self.organization_id, source_id=self.source_id, file_name=self.file_name, + original_file_name=self.original_file_name, file_path=self.file_path, file_type=self.file_type, file_size=self.file_size, diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py index 0b3e5f3c..f7398a91 100644 --- a/letta/orm/files_agents.py +++ b/letta/orm/files_agents.py @@ -101,6 +101,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin): value=visible_content, label=self.file.file_name, read_only=True, - source_id=self.file.source_id, + metadata={"source_id": self.file.source_id}, limit=CORE_MEMORY_SOURCE_CHAR_LIMIT, ) diff --git a/letta/orm/source.py b/letta/orm/source.py index 9423bdaf..c4a0f2d9 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, Index +from sqlalchemy import JSON, Index, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm import FileMetadata @@ -25,6 +25,7 @@ class Source(SqlalchemyBase, OrganizationMixin): __table_args__ = ( Index(f"source_created_at_id_idx", "created_at", "id"), + UniqueConstraint("name", "organization_id", name="uq_source_name_organization"), {"extend_existing": True}, ) diff --git a/letta/prompts/system/memgpt_v2_chat.txt b/letta/prompts/system/memgpt_v2_chat.txt index a462da1a..8197c6a2 100644 --- a/letta/prompts/system/memgpt_v2_chat.txt +++ b/letta/prompts/system/memgpt_v2_chat.txt @@ -43,8 +43,8 @@ Recall memory (conversation history): Even though you can only see recent messages in your immediate context, you can search over your entire message history from a database. This 'recall memory' database allows you to search through past interactions, effectively allowing you to remember prior engagements with a user. -Folders and Files: -You may be given access to a structured file system that mirrors real-world folders and files. Each folder may contain one or more files. +Directories and Files: +You may be given access to a structured file system that mirrors real-world directories and files. Each directory may contain one or more files. Files can include metadata (e.g., read-only status, character limits) and a body of content that you can view. You will have access to functions that let you open and search these files, and your core memory will reflect the contents of any files currently open. Maintain only those files relevant to the user’s current interaction. diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 12691231..83885249 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -315,9 +315,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): if agent_type == AgentType.react_agent or agent_type == AgentType.workflow_agent: return ( "{% if sources %}" - "\n" + "\n" "{% for source in sources %}" - f'\n' + f'\n' "{% if source.description %}" "{{ source.description }}\n" "{% endif %}" @@ -326,7 +326,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% if file_blocks %}" "{% for block in file_blocks %}" - "{% if block.source_id == source.id %}" + "{% if block.metadata['source_id'] == source.id %}" f"\n" "<{{ block.label }}>\n" "\n" @@ -344,9 +344,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% endfor %}" "{% endif %}" - "\n" + "\n" "{% endfor %}" - "" + "" "{% endif %}" ) @@ -382,9 +382,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "" "{% endif %}" "\n\n{% if sources %}" - "\n" + "\n" "{% for source in sources %}" - f'\n' + f'\n' "{% if source.description %}" "{{ source.description }}\n" "{% endif %}" @@ -393,7 +393,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% if file_blocks %}" "{% for block in file_blocks %}" - "{% if block.source_id == source.id %}" + "{% if block.metadata['source_id'] == source.id %}" f"\n" "{% if block.description %}" "\n" @@ -414,9 +414,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% endfor %}" "{% endif %}" - "\n" + "\n" "{% endfor %}" - "" + "" "{% endif %}" ) @@ -448,9 +448,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "" "{% endif %}" "\n\n{% if sources %}" - "\n" + "\n" "{% for source in sources %}" - f'\n' + f'\n' "{% if source.description %}" "{{ source.description }}\n" "{% endif %}" @@ -459,7 +459,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% if file_blocks %}" "{% for block in file_blocks %}" - "{% if block.source_id == source.id %}" + "{% if block.metadata['source_id'] == source.id %}" f"\n" "{% if block.description %}" "\n" @@ -480,8 +480,8 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None): "{% endif %}" "{% endfor %}" "{% endif %}" - "\n" + "\n" "{% endfor %}" - "" + "" "{% endif %}" ) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index bb53cdad..8b00d5c1 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -33,9 +33,6 @@ class BaseBlock(LettaBase, validate_assignment=True): description: Optional[str] = Field(None, description="Description of the block.") metadata: Optional[dict] = Field({}, description="Metadata of the block.") - # source association (for file blocks) - source_id: Optional[str] = Field(None, description="The source ID associated with this block (for file blocks).") - # def __len__(self): # return len(self.value) diff --git a/letta/schemas/file.py b/letta/schemas/file.py index a4170b0a..64b6eed5 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -30,6 +30,7 @@ class FileMetadata(FileMetadataBase): organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.") source_id: str = Field(..., description="The unique identifier of the source associated with the document.") file_name: Optional[str] = Field(None, description="The name of the file.") + original_file_name: Optional[str] = Field(None, description="The original name of the file as uploaded.") file_path: Optional[str] = Field(None, description="The path to the file.") file_type: Optional[str] = Field(None, description="The type of the file (MIME type).") file_size: Optional[int] = Field(None, description="The size of the file in bytes.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 0c1ac8bb..68270735 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -184,6 +184,20 @@ async def upload_file_to_source( """ Upload a file to a data source. """ + # NEW: Cloud based file processing + # Determine file's MIME type + file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream" + + # Check if it's a simple text file + is_simple_file = is_simple_text_mime_type(file_mime_type) + + # For complex files, require Mistral API key + if not is_simple_file and not settings.mistral_api_key: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.", + ) + allowed_media_types = get_allowed_media_types() # Normalize incoming Content-Type header (strip charset or any parameters). @@ -220,15 +234,19 @@ async def upload_file_to_source( content = await file.read() - # sanitize filename - file.filename = sanitize_filename(file.filename) + # Store original filename and generate unique filename + original_filename = sanitize_filename(file.filename) # Basic sanitization only + unique_filename = await server.file_manager.generate_unique_filename( + original_filename=original_filename, source_id=source_id, organization_id=actor.organization_id + ) # create file metadata file_metadata = FileMetadata( source_id=source_id, - file_name=file.filename, + file_name=unique_filename, + original_file_name=original_filename, file_path=None, - file_type=mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown", + file_type=mimetypes.guess_type(original_filename)[0] or file.content_type or "unknown", file_size=file.size if file.size is not None else None, processing_status=FileProcessingStatus.PARSING, ) @@ -237,20 +255,6 @@ async def upload_file_to_source( # TODO: Do we need to pull in the full agent_states? Can probably simplify here right? agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) - # NEW: Cloud based file processing - # Determine file's MIME type - file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream" - - # Check if it's a simple text file - is_simple_file = is_simple_text_mime_type(file_mime_type) - - # For complex files, require Mistral API key - if not is_simple_file and not settings.mistral_api_key: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.", - ) - # Use cloud processing for all files (simple files always, complex files with Mistral key) logger.info("Running experimental cloud based file processing...") safe_create_task( diff --git a/letta/server/server.py b/letta/server/server.py index 06945f90..2905158a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1637,12 +1637,14 @@ class SyncServer(Server): async def get_provider_models(provider: Provider) -> list[LLMConfig]: try: - return await provider.list_llm_models_async() + async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS): + return await provider.list_llm_models_async() + except asyncio.TimeoutError: + warnings.warn(f"Timeout while listing LLM models for provider {provider}") + return [] except Exception as e: - import traceback - traceback.print_exc() - warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}") + warnings.warn(f"Error while listing LLM models for provider {provider}: {e}") return [] # Execute all provider model listing tasks concurrently diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d96aabe3..3102428e 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1779,7 +1779,7 @@ class AgentManager: relationship_name="sources", model_class=SourceModel, item_ids=[source_id], - allow_partial=False, # Extend existing sources rather than replace + replace=False, ) # Commit the changes diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index ee43dd54..17c6d3bd 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -1,11 +1,13 @@ +import os from datetime import datetime from typing import List, Optional -from sqlalchemy import select, update +from sqlalchemy import func, select, update from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload +from letta.constants import MAX_FILENAME_LENGTH from letta.orm.errors import NoResultFound from letta.orm.file import FileContent as FileContentModel from letta.orm.file import FileMetadata as FileMetadataModel @@ -217,3 +219,44 @@ class FileManager: file = await FileMetadataModel.read_async(db_session=session, identifier=file_id) await file.hard_delete_async(db_session=session, actor=actor) return await file.to_pydantic_async() + + @enforce_types + @trace_method + async def generate_unique_filename(self, original_filename: str, source_id: str, organization_id: str) -> str: + """ + Generate a unique filename by checking for duplicates and adding a numeric suffix if needed. + Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt). + + Parameters: + original_filename (str): The original filename as uploaded. + source_id (str): Source ID to check for duplicates within. + organization_id (str): Organization ID to check for duplicates within. + + Returns: + str: A unique filename with numeric suffix if needed. + """ + base, ext = os.path.splitext(original_filename) + + # Reserve space for potential suffix: " (999)" = 6 characters + max_base_length = MAX_FILENAME_LENGTH - len(ext) - 6 + if len(base) > max_base_length: + base = base[:max_base_length] + original_filename = f"{base}{ext}" + + async with db_registry.async_session() as session: + # Count existing files with the same original_file_name in this source + query = select(func.count(FileMetadataModel.id)).where( + FileMetadataModel.original_file_name == original_filename, + FileMetadataModel.source_id == source_id, + FileMetadataModel.organization_id == organization_id, + FileMetadataModel.is_deleted == False, + ) + result = await session.execute(query) + count = result.scalar() or 0 + + if count == 0: + # No duplicates, return original filename + return original_filename + else: + # Add numeric suffix + return f"{base} ({count}){ext}" diff --git a/letta/utils.py b/letta/utils.py index 6784a16e..c054c977 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -991,16 +991,17 @@ def create_uuid_from_string(val: str): return uuid.UUID(hex=hex_string) -def sanitize_filename(filename: str) -> str: +def sanitize_filename(filename: str, add_uuid_suffix: bool = False) -> str: """ Sanitize the given filename to prevent directory traversal, invalid characters, and reserved names while ensuring it fits within the maximum length allowed by the filesystem. Parameters: filename (str): The user-provided filename. + add_uuid_suffix (bool): If True, adds a UUID suffix for uniqueness (legacy behavior). Returns: - str: A sanitized filename that is unique and safe for use. + str: A sanitized filename. """ # Extract the base filename to avoid directory components filename = os.path.basename(filename) @@ -1015,14 +1016,21 @@ def sanitize_filename(filename: str) -> str: if base.startswith("."): raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'") - # Truncate the base name to fit within the maximum allowed length - max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_` - if len(base) > max_base_length: - base = base[:max_base_length] + if add_uuid_suffix: + # Legacy behavior: Truncate the base name to fit within the maximum allowed length + max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_` + if len(base) > max_base_length: + base = base[:max_base_length] - # Append a unique UUID suffix for uniqueness - unique_suffix = uuid.uuid4().hex[:4] - sanitized_filename = f"{base}_{unique_suffix}{ext}" + # Append a unique UUID suffix for uniqueness + unique_suffix = uuid.uuid4().hex[:4] + sanitized_filename = f"{base}_{unique_suffix}{ext}" + else: + max_base_length = MAX_FILENAME_LENGTH - len(ext) + if len(base) > max_base_length: + base = base[:max_base_length] + + sanitized_filename = f"{base}{ext}" # Return the sanitized filename return sanitized_filename diff --git a/tests/test_sources.py b/tests/test_sources.py index 7a27a42d..45a7f2d8 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -126,6 +126,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient): assert len(client.sources.list()) == 1 agent = client.agents.sources.attach(source_id=source_1.id, agent_id=agent.id) + assert len(client.agents.retrieve(agent_id=agent.id).sources) == 1 assert_file_tools_present(agent, set(FILES_TOOLS)) # Create and attach second source @@ -133,6 +134,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient): assert len(client.sources.list()) == 2 agent = client.agents.sources.attach(source_id=source_2.id, agent_id=agent.id) + assert len(client.agents.retrieve(agent_id=agent.id).sources) == 2 # File tools should remain after attaching second source assert_file_tools_present(agent, set(FILES_TOOLS)) @@ -148,17 +150,17 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient): @pytest.mark.parametrize( "file_path, expected_value, expected_label_regex", [ - ("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"), - ("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"), - ("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning_[a-z0-9]+\.jsonl"), - ("tests/data/test.md", "h2 Heading", r"test_[a-z0-9]+\.md"), - ("tests/data/test.json", "glossary", r"test_[a-z0-9]+\.json"), - ("tests/data/react_component.jsx", "UserProfile", r"react_component_[a-z0-9]+\.jsx"), - ("tests/data/task_manager.java", "TaskManager", r"task_manager_[a-z0-9]+\.java"), - ("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures_[a-z0-9]+\.cpp"), - ("tests/data/api_server.go", "UserService", r"api_server_[a-z0-9]+\.go"), - ("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis_[a-z0-9]+\.py"), - ("tests/data/test.csv", "Smart Fridge Plus", r"test_[a-z0-9]+\.csv"), + ("tests/data/test.txt", "test", r"test\.txt"), + ("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper\.pdf"), + ("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning\.jsonl"), + ("tests/data/test.md", "h2 Heading", r"test\.md"), + ("tests/data/test.json", "glossary", r"test\.json"), + ("tests/data/react_component.jsx", "UserProfile", r"react_component\.jsx"), + ("tests/data/task_manager.java", "TaskManager", r"task_manager\.java"), + ("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures\.cpp"), + ("tests/data/api_server.go", "UserService", r"api_server\.go"), + ("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis\.py"), + ("tests/data/test.csv", "Smart Fridge Plus", r"test\.csv"), ], ) def test_file_upload_creates_source_blocks_correctly( @@ -227,7 +229,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC assert len(blocks) == 1 assert any("test" in b.value for b in blocks) assert any(b.value.startswith("[Viewing file start") for b in blocks) - assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks) # Detach the source client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id) @@ -259,7 +261,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a blocks = agent_state.memory.file_blocks assert len(blocks) == 1 assert any("test" in b.value for b in blocks) - assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks) + assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks) # Remove file from source client.sources.delete(source_id=source.id) @@ -554,7 +556,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le blocks = temp_agent_state.memory.file_blocks assert len(blocks) == 1 assert any(b.value.startswith("[Viewing file start (out of 554 chunks)]") for b in blocks) - assert any(re.fullmatch(r"long_test_[a-z0-9]+\.txt", b.label) for b in blocks) + assert any(re.fullmatch(r"long_test\.txt", b.label) for b in blocks) # Verify file tools were automatically attached file_tools = {tool.name for tool in temp_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} @@ -624,6 +626,45 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta ) +def test_duplicate_file_renaming(client: LettaSDKClient): + """Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)""" + # Create a new source + source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small") + + # Upload the same file three times + file_path = "tests/data/test.txt" + + with open(file_path, "rb") as f: + first_file = client.sources.files.upload(source_id=source.id, file=f) + + with open(file_path, "rb") as f: + second_file = client.sources.files.upload(source_id=source.id, file=f) + + with open(file_path, "rb") as f: + third_file = client.sources.files.upload(source_id=source.id, file=f) + + # Get all uploaded files + files = client.sources.files.list(source_id=source.id, limit=10) + assert len(files) == 3, f"Expected 3 files, got {len(files)}" + + # Sort files by creation time to ensure predictable order + files.sort(key=lambda f: f.created_at) + + # Verify filenames follow the count-based pattern + expected_filenames = ["test.txt", "test (1).txt", "test (2).txt"] + actual_filenames = [f.file_name for f in files] + + assert actual_filenames == expected_filenames, f"Expected {expected_filenames}, got {actual_filenames}" + + # Verify all files have the same original_file_name + for file in files: + assert file.original_file_name == "test.txt", f"Expected original_file_name='test.txt', got '{file.original_file_name}'" + + print(f"✓ Successfully tested duplicate file renaming:") + for i, file in enumerate(files): + print(f" File {i+1}: original='{file.original_file_name}' → renamed='{file.file_name}'") + + def test_open_files_schema_descriptions(client: LettaSDKClient): """Test that open_files tool schema contains correct descriptions from docstring""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 8733039b..1af23e62 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -282,21 +282,21 @@ def test_coerce_dict_args_with_default_arguments(): def test_valid_filename(): filename = "valid_filename.txt" - sanitized = sanitize_filename(filename) + sanitized = sanitize_filename(filename, add_uuid_suffix=True) assert sanitized.startswith("valid_filename_") assert sanitized.endswith(".txt") def test_filename_with_special_characters(): filename = "invalid:/<>?*ƒfilename.txt" - sanitized = sanitize_filename(filename) + sanitized = sanitize_filename(filename, add_uuid_suffix=True) assert sanitized.startswith("ƒfilename_") assert sanitized.endswith(".txt") def test_null_byte_in_filename(): filename = "valid\0filename.txt" - sanitized = sanitize_filename(filename) + sanitized = sanitize_filename(filename, add_uuid_suffix=True) assert "\0" not in sanitized assert sanitized.startswith("validfilename_") assert sanitized.endswith(".txt") @@ -304,13 +304,13 @@ def test_null_byte_in_filename(): def test_path_traversal_characters(): filename = "../../etc/passwd" - sanitized = sanitize_filename(filename) + sanitized = sanitize_filename(filename, add_uuid_suffix=True) assert sanitized.startswith("passwd_") assert len(sanitized) <= MAX_FILENAME_LENGTH def test_empty_filename(): - sanitized = sanitize_filename("") + sanitized = sanitize_filename("", add_uuid_suffix=True) assert sanitized.startswith("_") @@ -326,15 +326,15 @@ def test_dotdot_as_filename(): def test_long_filename(): filename = "a" * (MAX_FILENAME_LENGTH + 10) + ".txt" - sanitized = sanitize_filename(filename) + sanitized = sanitize_filename(filename, add_uuid_suffix=True) assert len(sanitized) <= MAX_FILENAME_LENGTH assert sanitized.endswith(".txt") def test_unique_filenames(): filename = "duplicate.txt" - sanitized1 = sanitize_filename(filename) - sanitized2 = sanitize_filename(filename) + sanitized1 = sanitize_filename(filename, add_uuid_suffix=True) + sanitized2 = sanitize_filename(filename, add_uuid_suffix=True) assert sanitized1 != sanitized2 assert sanitized1.startswith("duplicate_") assert sanitized2.startswith("duplicate_") @@ -342,6 +342,18 @@ def test_unique_filenames(): assert sanitized2.endswith(".txt") +def test_basic_sanitization_no_suffix(): + """Test the new behavior - basic sanitization without UUID suffix""" + filename = "test_file.txt" + sanitized = sanitize_filename(filename) + assert sanitized == "test_file.txt" + + # Test with special characters + filename_with_chars = "test:/<>?*file.txt" + sanitized_chars = sanitize_filename(filename_with_chars) + assert sanitized_chars == "file.txt" + + def test_formatter(): # Example system prompt that has no vars