feat: Only add suffix on duplication (#3120)

This commit is contained in:
Matthew Zhou
2025-07-01 13:48:38 -07:00
committed by GitHub
parent d064077f4d
commit efca9d8ea0
17 changed files with 259 additions and 78 deletions

View File

@@ -0,0 +1,68 @@
"""Add unique constraint to source names and also add original file name column
Revision ID: 46699adc71a7
Revises: 1af251a42c06
Create Date: 2025-07-01 13:30:48.279151
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "46699adc71a7"
down_revision: Union[str, None] = "1af251a42c06"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add files.original_file_name, de-duplicate source names, then enforce uniqueness.

    Steps:
      1. Add the nullable ``original_file_name`` column to ``files``.
      2. Rename existing duplicate source names (same name within one organization)
         by appending a numeric suffix, so the unique constraint can be created.
      3. Create the ``(name, organization_id)`` unique constraint on ``sources``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("files", sa.Column("original_file_name", sa.String(), nullable=True))

    # Handle existing duplicate source names before adding unique constraint
    connection = op.get_bind()

    # Find duplicates: every row beyond the oldest (by created_at) for each
    # (name, organization_id) pair needs a new name.
    result = connection.execute(
        sa.text(
            """
            WITH duplicates AS (
                SELECT name, organization_id,
                       ROW_NUMBER() OVER (PARTITION BY name, organization_id ORDER BY created_at) as rn,
                       id
                FROM sources
                WHERE (name, organization_id) IN (
                    SELECT name, organization_id
                    FROM sources
                    GROUP BY name, organization_id
                    HAVING COUNT(*) > 1
                )
            )
            SELECT id, name, rn
            FROM duplicates
            WHERE rn > 1
            """
        )
    )

    # Rename duplicates by appending a number suffix. The naive candidate
    # "<name>_<n>" can itself collide with a pre-existing source name in the
    # same organization (e.g. a source literally named "foo_2"), which would
    # make create_unique_constraint below fail — so probe upward until the
    # candidate name is free.
    for source_id, original_name, duplicate_number in result:
        suffix = duplicate_number
        new_name = f"{original_name}_{suffix}"
        while (
            connection.execute(
                sa.text(
                    "SELECT 1 FROM sources "
                    "WHERE name = :name AND organization_id = "
                    "(SELECT organization_id FROM sources WHERE id = :source_id)"
                ),
                {"name": new_name, "source_id": source_id},
            ).first()
            is not None
        ):
            suffix += 1
            new_name = f"{original_name}_{suffix}"
        connection.execute(
            sa.text("UPDATE sources SET name = :new_name WHERE id = :source_id"),
            {"new_name": new_name, "source_id": source_id},
        )

    op.create_unique_constraint("uq_source_name_organization", "sources", ["name", "organization_id"])
    # ### end Alembic commands ###
def downgrade() -> None:
    """Reverse of upgrade(): drop the uniqueness guarantee, then the column."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): source names that were renamed with a "_<n>" suffix during
    # upgrade() are NOT restored — that rename is a one-way data migration.
    op.drop_constraint("uq_source_name_organization", "sources", type_="unique")
    op.drop_column("files", "original_file_name")
    # ### end Alembic commands ###

View File

@@ -361,3 +361,5 @@ REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
# TODO: This is temporary, eventually use token-based eviction
MAX_FILES_OPEN = 5
GET_PROVIDERS_TIMEOUT_SECONDS = 10

View File

@@ -305,7 +305,7 @@ class OpenAIClient(LLMClientBase):
return response_stream
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[dict]:
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
"""Request embeddings given texts and embedding config"""
kwargs = self._prepare_client_kwargs_embedding(embedding_config)
client = AsyncOpenAI(**kwargs)

View File

@@ -49,6 +49,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
)
file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.")
original_file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The original name of the file as uploaded.")
file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.")
file_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The type of the file.")
file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.")
@@ -99,6 +100,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
organization_id=self.organization_id,
source_id=self.source_id,
file_name=self.file_name,
original_file_name=self.original_file_name,
file_path=self.file_path,
file_type=self.file_type,
file_size=self.file_size,

View File

@@ -101,6 +101,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
value=visible_content,
label=self.file.file_name,
read_only=True,
source_id=self.file.source_id,
metadata={"source_id": self.file.source_id},
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
)

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, Index
from sqlalchemy import JSON, Index, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm import FileMetadata
@@ -25,6 +25,7 @@ class Source(SqlalchemyBase, OrganizationMixin):
__table_args__ = (
Index(f"source_created_at_id_idx", "created_at", "id"),
UniqueConstraint("name", "organization_id", name="uq_source_name_organization"),
{"extend_existing": True},
)

View File

@@ -43,8 +43,8 @@ Recall memory (conversation history):
Even though you can only see recent messages in your immediate context, you can search over your entire message history from a database.
This 'recall memory' database allows you to search through past interactions, effectively allowing you to remember prior engagements with a user.
Folders and Files:
You may be given access to a structured file system that mirrors real-world folders and files. Each folder may contain one or more files.
Directories and Files:
You may be given access to a structured file system that mirrors real-world directories and files. Each directory may contain one or more files.
Files can include metadata (e.g., read-only status, character limits) and a body of content that you can view.
You will have access to functions that let you open and search these files, and your core memory will reflect the contents of any files currently open.
Maintain only those files relevant to the user's current interaction.

View File

@@ -315,9 +315,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
if agent_type == AgentType.react_agent or agent_type == AgentType.workflow_agent:
return (
"{% if sources %}"
"<folders>\n"
"<directories>\n"
"{% for source in sources %}"
f'<folder name="{{{{ source.name }}}}">\n'
f'<directory name="{{{{ source.name }}}}">\n'
"{% if source.description %}"
"<description>{{ source.description }}</description>\n"
"{% endif %}"
@@ -326,7 +326,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% if file_blocks %}"
"{% for block in file_blocks %}"
"{% if block.source_id == source.id %}"
"{% if block.metadata['source_id'] == source.id %}"
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\">\n"
"<{{ block.label }}>\n"
"<description>\n"
@@ -344,9 +344,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"</folder>\n"
"</directory>\n"
"{% endfor %}"
"</folders>"
"</directories>"
"{% endif %}"
)
@@ -382,9 +382,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"</tool_usage_rules>"
"{% endif %}"
"\n\n{% if sources %}"
"<folders>\n"
"<directories>\n"
"{% for source in sources %}"
f'<folder name="{{{{ source.name }}}}">\n'
f'<directory name="{{{{ source.name }}}}">\n'
"{% if source.description %}"
"<description>{{ source.description }}</description>\n"
"{% endif %}"
@@ -393,7 +393,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% if file_blocks %}"
"{% for block in file_blocks %}"
"{% if block.source_id == source.id %}"
"{% if block.metadata['source_id'] == source.id %}"
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\" name=\"{{{{ block.label }}}}\">\n"
"{% if block.description %}"
"<description>\n"
@@ -414,9 +414,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"</folder>\n"
"</directory>\n"
"{% endfor %}"
"</folders>"
"</directories>"
"{% endif %}"
)
@@ -448,9 +448,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"</tool_usage_rules>"
"{% endif %}"
"\n\n{% if sources %}"
"<folders>\n"
"<directories>\n"
"{% for source in sources %}"
f'<folder name="{{{{ source.name }}}}">\n'
f'<directory name="{{{{ source.name }}}}">\n'
"{% if source.description %}"
"<description>{{ source.description }}</description>\n"
"{% endif %}"
@@ -459,7 +459,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% if file_blocks %}"
"{% for block in file_blocks %}"
"{% if block.source_id == source.id %}"
"{% if block.metadata['source_id'] == source.id %}"
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\" name=\"{{{{ block.label }}}}\">\n"
"{% if block.description %}"
"<description>\n"
@@ -480,8 +480,8 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"</folder>\n"
"</directory>\n"
"{% endfor %}"
"</folders>"
"</directories>"
"{% endif %}"
)

View File

@@ -33,9 +33,6 @@ class BaseBlock(LettaBase, validate_assignment=True):
description: Optional[str] = Field(None, description="Description of the block.")
metadata: Optional[dict] = Field({}, description="Metadata of the block.")
# source association (for file blocks)
source_id: Optional[str] = Field(None, description="The source ID associated with this block (for file blocks).")
# def __len__(self):
# return len(self.value)

View File

@@ -30,6 +30,7 @@ class FileMetadata(FileMetadataBase):
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.")
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
file_name: Optional[str] = Field(None, description="The name of the file.")
original_file_name: Optional[str] = Field(None, description="The original name of the file as uploaded.")
file_path: Optional[str] = Field(None, description="The path to the file.")
file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
file_size: Optional[int] = Field(None, description="The size of the file in bytes.")

View File

@@ -184,6 +184,20 @@ async def upload_file_to_source(
"""
Upload a file to a data source.
"""
# NEW: Cloud based file processing
# Determine file's MIME type
file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
# Check if it's a simple text file
is_simple_file = is_simple_text_mime_type(file_mime_type)
# For complex files, require Mistral API key
if not is_simple_file and not settings.mistral_api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.",
)
allowed_media_types = get_allowed_media_types()
# Normalize incoming Content-Type header (strip charset or any parameters).
@@ -220,15 +234,19 @@ async def upload_file_to_source(
content = await file.read()
# sanitize filename
file.filename = sanitize_filename(file.filename)
# Store original filename and generate unique filename
original_filename = sanitize_filename(file.filename) # Basic sanitization only
unique_filename = await server.file_manager.generate_unique_filename(
original_filename=original_filename, source_id=source_id, organization_id=actor.organization_id
)
# create file metadata
file_metadata = FileMetadata(
source_id=source_id,
file_name=file.filename,
file_name=unique_filename,
original_file_name=original_filename,
file_path=None,
file_type=mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown",
file_type=mimetypes.guess_type(original_filename)[0] or file.content_type or "unknown",
file_size=file.size if file.size is not None else None,
processing_status=FileProcessingStatus.PARSING,
)
@@ -237,20 +255,6 @@ async def upload_file_to_source(
# TODO: Do we need to pull in the full agent_states? Can probably simplify here right?
agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
# NEW: Cloud based file processing
# Determine file's MIME type
file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
# Check if it's a simple text file
is_simple_file = is_simple_text_mime_type(file_mime_type)
# For complex files, require Mistral API key
if not is_simple_file and not settings.mistral_api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.",
)
# Use cloud processing for all files (simple files always, complex files with Mistral key)
logger.info("Running experimental cloud based file processing...")
safe_create_task(

View File

@@ -1637,12 +1637,14 @@ class SyncServer(Server):
async def get_provider_models(provider: Provider) -> list[LLMConfig]:
try:
return await provider.list_llm_models_async()
async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
return await provider.list_llm_models_async()
except asyncio.TimeoutError:
warnings.warn(f"Timeout while listing LLM models for provider {provider}")
return []
except Exception as e:
import traceback
traceback.print_exc()
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
warnings.warn(f"Error while listing LLM models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently

View File

@@ -1779,7 +1779,7 @@ class AgentManager:
relationship_name="sources",
model_class=SourceModel,
item_ids=[source_id],
allow_partial=False, # Extend existing sources rather than replace
replace=False,
)
# Commit the changes

View File

@@ -1,11 +1,13 @@
import os
from datetime import datetime
from typing import List, Optional
from sqlalchemy import select, update
from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from letta.constants import MAX_FILENAME_LENGTH
from letta.orm.errors import NoResultFound
from letta.orm.file import FileContent as FileContentModel
from letta.orm.file import FileMetadata as FileMetadataModel
@@ -217,3 +219,44 @@ class FileManager:
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
await file.hard_delete_async(db_session=session, actor=actor)
return await file.to_pydantic_async()
@enforce_types
@trace_method
async def generate_unique_filename(self, original_filename: str, source_id: str, organization_id: str) -> str:
    """
    Generate a unique filename by checking for duplicates and adding a numeric suffix if needed.
    Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt).

    Parameters:
        original_filename (str): The original filename as uploaded.
        source_id (str): Source ID to check for duplicates within.
        organization_id (str): Organization ID to check for duplicates within.

    Returns:
        str: A unique filename with numeric suffix if needed.
    """
    base, ext = os.path.splitext(original_filename)

    # Reserve space for potential suffix: " (999)" = 6 characters
    max_base_length = MAX_FILENAME_LENGTH - len(ext) - 6
    if len(base) > max_base_length:
        base = base[:max_base_length]
        original_filename = f"{base}{ext}"

    async with db_registry.async_session() as session:

        async def _name_in_use(name: str) -> bool:
            # True if a live (non-deleted) file in this source/org already
            # stores `name` as its generated file_name.
            query = select(func.count(FileMetadataModel.id)).where(
                FileMetadataModel.file_name == name,
                FileMetadataModel.source_id == source_id,
                FileMetadataModel.organization_id == organization_id,
                FileMetadataModel.is_deleted == False,
            )
            return ((await session.execute(query)).scalar() or 0) > 0

        # Count prior uploads of the same original name to seed the suffix
        # (preserves the file.txt -> file (1).txt -> file (2).txt sequence).
        count_query = select(func.count(FileMetadataModel.id)).where(
            FileMetadataModel.original_file_name == original_filename,
            FileMetadataModel.source_id == source_id,
            FileMetadataModel.organization_id == organization_id,
            FileMetadataModel.is_deleted == False,
        )
        count = (await session.execute(count_query)).scalar() or 0

        if count == 0 and not await _name_in_use(original_filename):
            # No duplicates and the name itself is free — keep it as-is.
            return original_filename

        # Counting original_file_name alone is not collision-proof: a user may
        # upload a file literally named "test (1).txt", or a deleted duplicate
        # may shift the count, so the candidate "<base> (n)<ext>" could match a
        # file_name that already exists. Probe upward until the name is free.
        n = max(count, 1)
        candidate = f"{base} ({n}){ext}"
        while await _name_in_use(candidate):
            n += 1
            candidate = f"{base} ({n}){ext}"
        return candidate

View File

@@ -991,16 +991,17 @@ def create_uuid_from_string(val: str):
return uuid.UUID(hex=hex_string)
def sanitize_filename(filename: str) -> str:
def sanitize_filename(filename: str, add_uuid_suffix: bool = False) -> str:
"""
Sanitize the given filename to prevent directory traversal, invalid characters,
and reserved names while ensuring it fits within the maximum length allowed by the filesystem.
Parameters:
filename (str): The user-provided filename.
add_uuid_suffix (bool): If True, adds a UUID suffix for uniqueness (legacy behavior).
Returns:
str: A sanitized filename that is unique and safe for use.
str: A sanitized filename.
"""
# Extract the base filename to avoid directory components
filename = os.path.basename(filename)
@@ -1015,14 +1016,21 @@ def sanitize_filename(filename: str) -> str:
if base.startswith("."):
raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'")
# Truncate the base name to fit within the maximum allowed length
max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_`
if len(base) > max_base_length:
base = base[:max_base_length]
if add_uuid_suffix:
# Legacy behavior: Truncate the base name to fit within the maximum allowed length
max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_`
if len(base) > max_base_length:
base = base[:max_base_length]
# Append a unique UUID suffix for uniqueness
unique_suffix = uuid.uuid4().hex[:4]
sanitized_filename = f"{base}_{unique_suffix}{ext}"
# Append a unique UUID suffix for uniqueness
unique_suffix = uuid.uuid4().hex[:4]
sanitized_filename = f"{base}_{unique_suffix}{ext}"
else:
max_base_length = MAX_FILENAME_LENGTH - len(ext)
if len(base) > max_base_length:
base = base[:max_base_length]
sanitized_filename = f"{base}{ext}"
# Return the sanitized filename
return sanitized_filename

View File

@@ -126,6 +126,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
assert len(client.sources.list()) == 1
agent = client.agents.sources.attach(source_id=source_1.id, agent_id=agent.id)
assert len(client.agents.retrieve(agent_id=agent.id).sources) == 1
assert_file_tools_present(agent, set(FILES_TOOLS))
# Create and attach second source
@@ -133,6 +134,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
assert len(client.sources.list()) == 2
agent = client.agents.sources.attach(source_id=source_2.id, agent_id=agent.id)
assert len(client.agents.retrieve(agent_id=agent.id).sources) == 2
# File tools should remain after attaching second source
assert_file_tools_present(agent, set(FILES_TOOLS))
@@ -148,17 +150,17 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
@pytest.mark.parametrize(
"file_path, expected_value, expected_label_regex",
[
("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"),
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"),
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning_[a-z0-9]+\.jsonl"),
("tests/data/test.md", "h2 Heading", r"test_[a-z0-9]+\.md"),
("tests/data/test.json", "glossary", r"test_[a-z0-9]+\.json"),
("tests/data/react_component.jsx", "UserProfile", r"react_component_[a-z0-9]+\.jsx"),
("tests/data/task_manager.java", "TaskManager", r"task_manager_[a-z0-9]+\.java"),
("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures_[a-z0-9]+\.cpp"),
("tests/data/api_server.go", "UserService", r"api_server_[a-z0-9]+\.go"),
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis_[a-z0-9]+\.py"),
("tests/data/test.csv", "Smart Fridge Plus", r"test_[a-z0-9]+\.csv"),
("tests/data/test.txt", "test", r"test\.txt"),
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper\.pdf"),
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning\.jsonl"),
("tests/data/test.md", "h2 Heading", r"test\.md"),
("tests/data/test.json", "glossary", r"test\.json"),
("tests/data/react_component.jsx", "UserProfile", r"react_component\.jsx"),
("tests/data/task_manager.java", "TaskManager", r"task_manager\.java"),
("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures\.cpp"),
("tests/data/api_server.go", "UserService", r"api_server\.go"),
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis\.py"),
("tests/data/test.csv", "Smart Fridge Plus", r"test\.csv"),
],
)
def test_file_upload_creates_source_blocks_correctly(
@@ -227,7 +229,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
assert len(blocks) == 1
assert any("test" in b.value for b in blocks)
assert any(b.value.startswith("[Viewing file start") for b in blocks)
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks)
# Detach the source
client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id)
@@ -259,7 +261,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
blocks = agent_state.memory.file_blocks
assert len(blocks) == 1
assert any("test" in b.value for b in blocks)
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks)
# Remove file from source
client.sources.delete(source_id=source.id)
@@ -554,7 +556,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le
blocks = temp_agent_state.memory.file_blocks
assert len(blocks) == 1
assert any(b.value.startswith("[Viewing file start (out of 554 chunks)]") for b in blocks)
assert any(re.fullmatch(r"long_test_[a-z0-9]+\.txt", b.label) for b in blocks)
assert any(re.fullmatch(r"long_test\.txt", b.label) for b in blocks)
# Verify file tools were automatically attached
file_tools = {tool.name for tool in temp_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
@@ -624,6 +626,45 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
)
def test_duplicate_file_renaming(client: LettaSDKClient):
    """Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
    # Fresh source so duplicate counting starts from zero.
    source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")

    # Upload the identical file three times in a row.
    file_path = "tests/data/test.txt"
    for _ in range(3):
        with open(file_path, "rb") as f:
            client.sources.files.upload(source_id=source.id, file=f)

    # All three uploads should be present.
    files = client.sources.files.list(source_id=source.id, limit=10)
    assert len(files) == 3, f"Expected 3 files, got {len(files)}"

    # Order by creation time so the suffix sequence is deterministic.
    files = sorted(files, key=lambda f: f.created_at)

    # Stored names must follow the filesystem-style "(n)" suffix convention.
    expected_filenames = ["test.txt", "test (1).txt", "test (2).txt"]
    actual_filenames = [f.file_name for f in files]
    assert actual_filenames == expected_filenames, f"Expected {expected_filenames}, got {actual_filenames}"

    # Every record keeps the untouched upload name.
    for file in files:
        assert file.original_file_name == "test.txt", f"Expected original_file_name='test.txt', got '{file.original_file_name}'"

    print(f"✓ Successfully tested duplicate file renaming:")
    for i, file in enumerate(files):
        print(f"  File {i+1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
def test_open_files_schema_descriptions(client: LettaSDKClient):
"""Test that open_files tool schema contains correct descriptions from docstring"""

View File

@@ -282,21 +282,21 @@ def test_coerce_dict_args_with_default_arguments():
def test_valid_filename():
filename = "valid_filename.txt"
sanitized = sanitize_filename(filename)
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
assert sanitized.startswith("valid_filename_")
assert sanitized.endswith(".txt")
def test_filename_with_special_characters():
filename = "invalid:/<>?*ƒfilename.txt"
sanitized = sanitize_filename(filename)
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
assert sanitized.startswith("ƒfilename_")
assert sanitized.endswith(".txt")
def test_null_byte_in_filename():
filename = "valid\0filename.txt"
sanitized = sanitize_filename(filename)
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
assert "\0" not in sanitized
assert sanitized.startswith("validfilename_")
assert sanitized.endswith(".txt")
@@ -304,13 +304,13 @@ def test_null_byte_in_filename():
def test_path_traversal_characters():
filename = "../../etc/passwd"
sanitized = sanitize_filename(filename)
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
assert sanitized.startswith("passwd_")
assert len(sanitized) <= MAX_FILENAME_LENGTH
def test_empty_filename():
sanitized = sanitize_filename("")
sanitized = sanitize_filename("", add_uuid_suffix=True)
assert sanitized.startswith("_")
@@ -326,15 +326,15 @@ def test_dotdot_as_filename():
def test_long_filename():
filename = "a" * (MAX_FILENAME_LENGTH + 10) + ".txt"
sanitized = sanitize_filename(filename)
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
assert len(sanitized) <= MAX_FILENAME_LENGTH
assert sanitized.endswith(".txt")
def test_unique_filenames():
filename = "duplicate.txt"
sanitized1 = sanitize_filename(filename)
sanitized2 = sanitize_filename(filename)
sanitized1 = sanitize_filename(filename, add_uuid_suffix=True)
sanitized2 = sanitize_filename(filename, add_uuid_suffix=True)
assert sanitized1 != sanitized2
assert sanitized1.startswith("duplicate_")
assert sanitized2.startswith("duplicate_")
@@ -342,6 +342,18 @@ def test_unique_filenames():
assert sanitized2.endswith(".txt")
def test_basic_sanitization_no_suffix():
"""Test the new behavior - basic sanitization without UUID suffix"""
filename = "test_file.txt"
sanitized = sanitize_filename(filename)
assert sanitized == "test_file.txt"
# Test with special characters
filename_with_chars = "test:/<>?*file.txt"
sanitized_chars = sanitize_filename(filename_with_chars)
assert sanitized_chars == "file.txt"
def test_formatter():
# Example system prompt that has no vars