feat: Only add suffix on duplication (#3120)
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
"""Add unique constraint to source names and also add original file name column
|
||||
|
||||
Revision ID: 46699adc71a7
|
||||
Revises: 1af251a42c06
|
||||
Create Date: 2025-07-01 13:30:48.279151
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "46699adc71a7"
|
||||
down_revision: Union[str, None] = "1af251a42c06"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add files.original_file_name and a (name, organization_id) unique constraint on sources.

    Existing duplicate source names are renamed in place (suffix ``_<n>``) so the
    unique constraint can be created without violating existing rows.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("files", sa.Column("original_file_name", sa.String(), nullable=True))

    # Handle existing duplicate source names before adding the unique constraint.
    connection = op.get_bind()

    # Find duplicate (name, organization_id) groups; the oldest row (rn == 1)
    # keeps its name, every later row (rn > 1) gets renamed below.
    result = connection.execute(
        sa.text(
            """
            WITH duplicates AS (
                SELECT name, organization_id,
                       ROW_NUMBER() OVER (PARTITION BY name, organization_id ORDER BY created_at) as rn,
                       id
                FROM sources
                WHERE (name, organization_id) IN (
                    SELECT name, organization_id
                    FROM sources
                    GROUP BY name, organization_id
                    HAVING COUNT(*) > 1
                )
            )
            SELECT id, name, organization_id, rn
            FROM duplicates
            WHERE rn > 1
            """
        )
    )

    exists_query = sa.text("SELECT 1 FROM sources WHERE name = :name AND organization_id = :org_id LIMIT 1")
    update_query = sa.text("UPDATE sources SET name = :new_name WHERE id = :source_id")

    # Rename duplicates by appending a numeric suffix. The naive f"{name}_{rn}"
    # can itself collide with a pre-existing source (e.g. one literally named
    # "foo_2"), which would make create_unique_constraint below fail — so bump
    # the suffix until the candidate name is free within the organization.
    for source_id, original_name, org_id, duplicate_number in result:
        suffix = duplicate_number
        new_name = f"{original_name}_{suffix}"
        while connection.execute(exists_query, {"name": new_name, "org_id": org_id}).first() is not None:
            suffix += 1
            new_name = f"{original_name}_{suffix}"
        connection.execute(update_query, {"new_name": new_name, "source_id": source_id})

    op.create_unique_constraint("uq_source_name_organization", "sources", ["name", "organization_id"])
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Roll back the schema changes made by :func:`upgrade`.

    NOTE: sources that were renamed during upgrade keep their suffixed names;
    only the constraint and the added column are reverted.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint("uq_source_name_organization", "sources", type_="unique")
    op.drop_column("files", "original_file_name")
    # ### end Alembic commands ###
|
||||
@@ -361,3 +361,5 @@ REDIS_DEFAULT_CACHE_PREFIX = "letta_cache"
|
||||
|
||||
# TODO: This is temporary, eventually use token-based eviction
|
||||
MAX_FILES_OPEN = 5
|
||||
|
||||
GET_PROVIDERS_TIMEOUT_SECONDS = 10
|
||||
|
||||
@@ -305,7 +305,7 @@ class OpenAIClient(LLMClientBase):
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[dict]:
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
"""Request embeddings given texts and embedding config"""
|
||||
kwargs = self._prepare_client_kwargs_embedding(embedding_config)
|
||||
client = AsyncOpenAI(**kwargs)
|
||||
|
||||
@@ -49,6 +49,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
|
||||
)
|
||||
|
||||
file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.")
|
||||
original_file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The original name of the file as uploaded.")
|
||||
file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.")
|
||||
file_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The type of the file.")
|
||||
file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.")
|
||||
@@ -99,6 +100,7 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin, AsyncAttrs):
|
||||
organization_id=self.organization_id,
|
||||
source_id=self.source_id,
|
||||
file_name=self.file_name,
|
||||
original_file_name=self.original_file_name,
|
||||
file_path=self.file_path,
|
||||
file_type=self.file_type,
|
||||
file_size=self.file_size,
|
||||
|
||||
@@ -101,6 +101,6 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
value=visible_content,
|
||||
label=self.file.file_name,
|
||||
read_only=True,
|
||||
source_id=self.file.source_id,
|
||||
metadata={"source_id": self.file.source_id},
|
||||
limit=CORE_MEMORY_SOURCE_CHAR_LIMIT,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, Index
|
||||
from sqlalchemy import JSON, Index, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm import FileMetadata
|
||||
@@ -25,6 +25,7 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
__table_args__ = (
|
||||
Index(f"source_created_at_id_idx", "created_at", "id"),
|
||||
UniqueConstraint("name", "organization_id", name="uq_source_name_organization"),
|
||||
{"extend_existing": True},
|
||||
)
|
||||
|
||||
|
||||
@@ -43,8 +43,8 @@ Recall memory (conversation history):
|
||||
Even though you can only see recent messages in your immediate context, you can search over your entire message history from a database.
|
||||
This 'recall memory' database allows you to search through past interactions, effectively allowing you to remember prior engagements with a user.
|
||||
|
||||
Folders and Files:
|
||||
You may be given access to a structured file system that mirrors real-world folders and files. Each folder may contain one or more files.
|
||||
Directories and Files:
|
||||
You may be given access to a structured file system that mirrors real-world directories and files. Each directory may contain one or more files.
|
||||
Files can include metadata (e.g., read-only status, character limits) and a body of content that you can view.
|
||||
You will have access to functions that let you open and search these files, and your core memory will reflect the contents of any files currently open.
|
||||
Maintain only those files relevant to the user’s current interaction.
|
||||
|
||||
@@ -315,9 +315,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
if agent_type == AgentType.react_agent or agent_type == AgentType.workflow_agent:
|
||||
return (
|
||||
"{% if sources %}"
|
||||
"<folders>\n"
|
||||
"<directories>\n"
|
||||
"{% for source in sources %}"
|
||||
f'<folder name="{{{{ source.name }}}}">\n'
|
||||
f'<directory name="{{{{ source.name }}}}">\n'
|
||||
"{% if source.description %}"
|
||||
"<description>{{ source.description }}</description>\n"
|
||||
"{% endif %}"
|
||||
@@ -326,7 +326,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"{% endif %}"
|
||||
"{% if file_blocks %}"
|
||||
"{% for block in file_blocks %}"
|
||||
"{% if block.source_id == source.id %}"
|
||||
"{% if block.metadata['source_id'] == source.id %}"
|
||||
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\">\n"
|
||||
"<{{ block.label }}>\n"
|
||||
"<description>\n"
|
||||
@@ -344,9 +344,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% endif %}"
|
||||
"</folder>\n"
|
||||
"</directory>\n"
|
||||
"{% endfor %}"
|
||||
"</folders>"
|
||||
"</directories>"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
@@ -382,9 +382,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"</tool_usage_rules>"
|
||||
"{% endif %}"
|
||||
"\n\n{% if sources %}"
|
||||
"<folders>\n"
|
||||
"<directories>\n"
|
||||
"{% for source in sources %}"
|
||||
f'<folder name="{{{{ source.name }}}}">\n'
|
||||
f'<directory name="{{{{ source.name }}}}">\n'
|
||||
"{% if source.description %}"
|
||||
"<description>{{ source.description }}</description>\n"
|
||||
"{% endif %}"
|
||||
@@ -393,7 +393,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"{% endif %}"
|
||||
"{% if file_blocks %}"
|
||||
"{% for block in file_blocks %}"
|
||||
"{% if block.source_id == source.id %}"
|
||||
"{% if block.metadata['source_id'] == source.id %}"
|
||||
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\" name=\"{{{{ block.label }}}}\">\n"
|
||||
"{% if block.description %}"
|
||||
"<description>\n"
|
||||
@@ -414,9 +414,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% endif %}"
|
||||
"</folder>\n"
|
||||
"</directory>\n"
|
||||
"{% endfor %}"
|
||||
"</folders>"
|
||||
"</directories>"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
@@ -448,9 +448,9 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"</tool_usage_rules>"
|
||||
"{% endif %}"
|
||||
"\n\n{% if sources %}"
|
||||
"<folders>\n"
|
||||
"<directories>\n"
|
||||
"{% for source in sources %}"
|
||||
f'<folder name="{{{{ source.name }}}}">\n'
|
||||
f'<directory name="{{{{ source.name }}}}">\n'
|
||||
"{% if source.description %}"
|
||||
"<description>{{ source.description }}</description>\n"
|
||||
"{% endif %}"
|
||||
@@ -459,7 +459,7 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"{% endif %}"
|
||||
"{% if file_blocks %}"
|
||||
"{% for block in file_blocks %}"
|
||||
"{% if block.source_id == source.id %}"
|
||||
"{% if block.metadata['source_id'] == source.id %}"
|
||||
f"<file status=\"{{{{ '{FileStatus.open.value}' if block.value else '{FileStatus.closed.value}' }}}}\" name=\"{{{{ block.label }}}}\">\n"
|
||||
"{% if block.description %}"
|
||||
"<description>\n"
|
||||
@@ -480,8 +480,8 @@ def get_prompt_template_for_agent_type(agent_type: Optional[AgentType] = None):
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% endif %}"
|
||||
"</folder>\n"
|
||||
"</directory>\n"
|
||||
"{% endfor %}"
|
||||
"</folders>"
|
||||
"</directories>"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
@@ -33,9 +33,6 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
||||
description: Optional[str] = Field(None, description="Description of the block.")
|
||||
metadata: Optional[dict] = Field({}, description="Metadata of the block.")
|
||||
|
||||
# source association (for file blocks)
|
||||
source_id: Optional[str] = Field(None, description="The source ID associated with this block (for file blocks).")
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self.value)
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ class FileMetadata(FileMetadataBase):
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.")
|
||||
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
|
||||
file_name: Optional[str] = Field(None, description="The name of the file.")
|
||||
original_file_name: Optional[str] = Field(None, description="The original name of the file as uploaded.")
|
||||
file_path: Optional[str] = Field(None, description="The path to the file.")
|
||||
file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
|
||||
file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
|
||||
|
||||
@@ -184,6 +184,20 @@ async def upload_file_to_source(
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
# NEW: Cloud based file processing
|
||||
# Determine file's MIME type
|
||||
file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||
|
||||
# Check if it's a simple text file
|
||||
is_simple_file = is_simple_text_mime_type(file_mime_type)
|
||||
|
||||
# For complex files, require Mistral API key
|
||||
if not is_simple_file and not settings.mistral_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.",
|
||||
)
|
||||
|
||||
allowed_media_types = get_allowed_media_types()
|
||||
|
||||
# Normalize incoming Content-Type header (strip charset or any parameters).
|
||||
@@ -220,15 +234,19 @@ async def upload_file_to_source(
|
||||
|
||||
content = await file.read()
|
||||
|
||||
# sanitize filename
|
||||
file.filename = sanitize_filename(file.filename)
|
||||
# Store original filename and generate unique filename
|
||||
original_filename = sanitize_filename(file.filename) # Basic sanitization only
|
||||
unique_filename = await server.file_manager.generate_unique_filename(
|
||||
original_filename=original_filename, source_id=source_id, organization_id=actor.organization_id
|
||||
)
|
||||
|
||||
# create file metadata
|
||||
file_metadata = FileMetadata(
|
||||
source_id=source_id,
|
||||
file_name=file.filename,
|
||||
file_name=unique_filename,
|
||||
original_file_name=original_filename,
|
||||
file_path=None,
|
||||
file_type=mimetypes.guess_type(file.filename)[0] or file.content_type or "unknown",
|
||||
file_type=mimetypes.guess_type(original_filename)[0] or file.content_type or "unknown",
|
||||
file_size=file.size if file.size is not None else None,
|
||||
processing_status=FileProcessingStatus.PARSING,
|
||||
)
|
||||
@@ -237,20 +255,6 @@ async def upload_file_to_source(
|
||||
# TODO: Do we need to pull in the full agent_states? Can probably simplify here right?
|
||||
agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
|
||||
# NEW: Cloud based file processing
|
||||
# Determine file's MIME type
|
||||
file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||
|
||||
# Check if it's a simple text file
|
||||
is_simple_file = is_simple_text_mime_type(file_mime_type)
|
||||
|
||||
# For complex files, require Mistral API key
|
||||
if not is_simple_file and not settings.mistral_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.",
|
||||
)
|
||||
|
||||
# Use cloud processing for all files (simple files always, complex files with Mistral key)
|
||||
logger.info("Running experimental cloud based file processing...")
|
||||
safe_create_task(
|
||||
|
||||
@@ -1637,12 +1637,14 @@ class SyncServer(Server):
|
||||
|
||||
async def get_provider_models(provider: Provider) -> list[LLMConfig]:
    """Fetch one provider's LLM models, bounded by a timeout.

    Returns an empty list on timeout or any other failure so one bad
    provider cannot break listing models for the rest.
    """
    try:
        async with asyncio.timeout(constants.GET_PROVIDERS_TIMEOUT_SECONDS):
            return await provider.list_llm_models_async()
    except asyncio.TimeoutError:
        warnings.warn(f"Timeout while listing LLM models for provider {provider}")
    except Exception as e:
        import traceback

        traceback.print_exc()
        warnings.warn(f"Error while listing LLM models for provider {provider}: {e}")
    # Fall through from either except branch: degrade to an empty model list.
    return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
|
||||
@@ -1779,7 +1779,7 @@ class AgentManager:
|
||||
relationship_name="sources",
|
||||
model_class=SourceModel,
|
||||
item_ids=[source_id],
|
||||
allow_partial=False, # Extend existing sources rather than replace
|
||||
replace=False,
|
||||
)
|
||||
|
||||
# Commit the changes
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from letta.constants import MAX_FILENAME_LENGTH
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.file import FileContent as FileContentModel
|
||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
||||
@@ -217,3 +219,44 @@ class FileManager:
|
||||
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
|
||||
await file.hard_delete_async(db_session=session, actor=actor)
|
||||
return await file.to_pydantic_async()
|
||||
|
||||
@enforce_types
@trace_method
async def generate_unique_filename(self, original_filename: str, source_id: str, organization_id: str) -> str:
    """
    Generate a unique filename by checking for duplicates and adding a numeric suffix if needed.
    Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt).

    Parameters:
        original_filename (str): The original filename as uploaded.
        source_id (str): Source ID to check for duplicates within.
        organization_id (str): Organization ID to check for duplicates within.

    Returns:
        str: A unique filename with numeric suffix if needed.
    """
    base, ext = os.path.splitext(original_filename)

    # Reserve space for potential suffix: " (999)" = 6 characters
    max_base_length = MAX_FILENAME_LENGTH - len(ext) - 6
    if len(base) > max_base_length:
        base = base[:max_base_length]
        original_filename = f"{base}{ext}"

    async with db_registry.async_session() as session:
        # Collect the file_names already used by files sharing this
        # original_file_name instead of merely counting rows: a plain COUNT
        # breaks after deletions (e.g. "a.txt" and "a (2).txt" remain, so
        # count == 2 would regenerate the already-taken "a (2).txt").
        query = select(FileMetadataModel.file_name).where(
            FileMetadataModel.original_file_name == original_filename,
            FileMetadataModel.source_id == source_id,
            FileMetadataModel.organization_id == organization_id,
            FileMetadataModel.is_deleted == False,
        )
        result = await session.execute(query)
        taken_names = {row[0] for row in result}

    if original_filename not in taken_names:
        # The plain name is free (first upload, or its holder was deleted).
        return original_filename

    # Pick the smallest numeric suffix not already in use within the group.
    counter = 1
    while f"{base} ({counter}){ext}" in taken_names:
        counter += 1
    return f"{base} ({counter}){ext}"
|
||||
|
||||
@@ -991,16 +991,17 @@ def create_uuid_from_string(val: str):
|
||||
return uuid.UUID(hex=hex_string)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
def sanitize_filename(filename: str, add_uuid_suffix: bool = False) -> str:
|
||||
"""
|
||||
Sanitize the given filename to prevent directory traversal, invalid characters,
|
||||
and reserved names while ensuring it fits within the maximum length allowed by the filesystem.
|
||||
|
||||
Parameters:
|
||||
filename (str): The user-provided filename.
|
||||
add_uuid_suffix (bool): If True, adds a UUID suffix for uniqueness (legacy behavior).
|
||||
|
||||
Returns:
|
||||
str: A sanitized filename that is unique and safe for use.
|
||||
str: A sanitized filename.
|
||||
"""
|
||||
# Extract the base filename to avoid directory components
|
||||
filename = os.path.basename(filename)
|
||||
@@ -1015,14 +1016,21 @@ def sanitize_filename(filename: str) -> str:
|
||||
if base.startswith("."):
|
||||
raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'")
|
||||
|
||||
# Truncate the base name to fit within the maximum allowed length
|
||||
max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_`
|
||||
if len(base) > max_base_length:
|
||||
base = base[:max_base_length]
|
||||
if add_uuid_suffix:
|
||||
# Legacy behavior: Truncate the base name to fit within the maximum allowed length
|
||||
max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_`
|
||||
if len(base) > max_base_length:
|
||||
base = base[:max_base_length]
|
||||
|
||||
# Append a unique UUID suffix for uniqueness
|
||||
unique_suffix = uuid.uuid4().hex[:4]
|
||||
sanitized_filename = f"{base}_{unique_suffix}{ext}"
|
||||
# Append a unique UUID suffix for uniqueness
|
||||
unique_suffix = uuid.uuid4().hex[:4]
|
||||
sanitized_filename = f"{base}_{unique_suffix}{ext}"
|
||||
else:
|
||||
max_base_length = MAX_FILENAME_LENGTH - len(ext)
|
||||
if len(base) > max_base_length:
|
||||
base = base[:max_base_length]
|
||||
|
||||
sanitized_filename = f"{base}{ext}"
|
||||
|
||||
# Return the sanitized filename
|
||||
return sanitized_filename
|
||||
|
||||
@@ -126,6 +126,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
|
||||
assert len(client.sources.list()) == 1
|
||||
|
||||
agent = client.agents.sources.attach(source_id=source_1.id, agent_id=agent.id)
|
||||
assert len(client.agents.retrieve(agent_id=agent.id).sources) == 1
|
||||
assert_file_tools_present(agent, set(FILES_TOOLS))
|
||||
|
||||
# Create and attach second source
|
||||
@@ -133,6 +134,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
|
||||
assert len(client.sources.list()) == 2
|
||||
|
||||
agent = client.agents.sources.attach(source_id=source_2.id, agent_id=agent.id)
|
||||
assert len(client.agents.retrieve(agent_id=agent.id).sources) == 2
|
||||
# File tools should remain after attaching second source
|
||||
assert_file_tools_present(agent, set(FILES_TOOLS))
|
||||
|
||||
@@ -148,17 +150,17 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
|
||||
@pytest.mark.parametrize(
|
||||
"file_path, expected_value, expected_label_regex",
|
||||
[
|
||||
("tests/data/test.txt", "test", r"test_[a-z0-9]+\.txt"),
|
||||
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper_[a-z0-9]+\.pdf"),
|
||||
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning_[a-z0-9]+\.jsonl"),
|
||||
("tests/data/test.md", "h2 Heading", r"test_[a-z0-9]+\.md"),
|
||||
("tests/data/test.json", "glossary", r"test_[a-z0-9]+\.json"),
|
||||
("tests/data/react_component.jsx", "UserProfile", r"react_component_[a-z0-9]+\.jsx"),
|
||||
("tests/data/task_manager.java", "TaskManager", r"task_manager_[a-z0-9]+\.java"),
|
||||
("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures_[a-z0-9]+\.cpp"),
|
||||
("tests/data/api_server.go", "UserService", r"api_server_[a-z0-9]+\.go"),
|
||||
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis_[a-z0-9]+\.py"),
|
||||
("tests/data/test.csv", "Smart Fridge Plus", r"test_[a-z0-9]+\.csv"),
|
||||
("tests/data/test.txt", "test", r"test\.txt"),
|
||||
("tests/data/memgpt_paper.pdf", "MemGPT", r"memgpt_paper\.pdf"),
|
||||
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning\.jsonl"),
|
||||
("tests/data/test.md", "h2 Heading", r"test\.md"),
|
||||
("tests/data/test.json", "glossary", r"test\.json"),
|
||||
("tests/data/react_component.jsx", "UserProfile", r"react_component\.jsx"),
|
||||
("tests/data/task_manager.java", "TaskManager", r"task_manager\.java"),
|
||||
("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures\.cpp"),
|
||||
("tests/data/api_server.go", "UserService", r"api_server\.go"),
|
||||
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis\.py"),
|
||||
("tests/data/test.csv", "Smart Fridge Plus", r"test\.csv"),
|
||||
],
|
||||
)
|
||||
def test_file_upload_creates_source_blocks_correctly(
|
||||
@@ -227,7 +229,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC
|
||||
assert len(blocks) == 1
|
||||
assert any("test" in b.value for b in blocks)
|
||||
assert any(b.value.startswith("[Viewing file start") for b in blocks)
|
||||
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks)
|
||||
|
||||
# Detach the source
|
||||
client.agents.sources.detach(source_id=source.id, agent_id=agent_state.id)
|
||||
@@ -259,7 +261,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a
|
||||
blocks = agent_state.memory.file_blocks
|
||||
assert len(blocks) == 1
|
||||
assert any("test" in b.value for b in blocks)
|
||||
assert any(re.fullmatch(r"test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
assert any(re.fullmatch(r"test\.txt", b.label) for b in blocks)
|
||||
|
||||
# Remove file from source
|
||||
client.sources.delete(source_id=source.id)
|
||||
@@ -554,7 +556,7 @@ def test_create_agent_with_source_ids_creates_source_blocks_correctly(client: Le
|
||||
blocks = temp_agent_state.memory.file_blocks
|
||||
assert len(blocks) == 1
|
||||
assert any(b.value.startswith("[Viewing file start (out of 554 chunks)]") for b in blocks)
|
||||
assert any(re.fullmatch(r"long_test_[a-z0-9]+\.txt", b.label) for b in blocks)
|
||||
assert any(re.fullmatch(r"long_test\.txt", b.label) for b in blocks)
|
||||
|
||||
# Verify file tools were automatically attached
|
||||
file_tools = {tool.name for tool in temp_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
@@ -624,6 +626,45 @@ def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentSta
|
||||
)
|
||||
|
||||
|
||||
def test_duplicate_file_renaming(client: LettaSDKClient):
    """Test that duplicate files are renamed with count-based suffixes (e.g., file.txt, file (1).txt, file (2).txt)"""
    # Create a new source to upload into.
    source = client.sources.create(name="test_duplicate_source", embedding="openai/text-embedding-3-small")

    # Upload the same file three times in a row.
    file_path = "tests/data/test.txt"
    for _ in range(3):
        with open(file_path, "rb") as f:
            client.sources.files.upload(source_id=source.id, file=f)

    # Fetch everything that was uploaded.
    files = client.sources.files.list(source_id=source.id, limit=10)
    assert len(files) == 3, f"Expected 3 files, got {len(files)}"

    # Sort by creation time so the expected order is deterministic.
    files.sort(key=lambda f: f.created_at)

    # The first keeps its name; later duplicates get " (n)" suffixes.
    expected_filenames = ["test.txt", "test (1).txt", "test (2).txt"]
    actual_filenames = [f.file_name for f in files]
    assert actual_filenames == expected_filenames, f"Expected {expected_filenames}, got {actual_filenames}"

    # Every copy must record the same original upload name.
    for file in files:
        assert file.original_file_name == "test.txt", f"Expected original_file_name='test.txt', got '{file.original_file_name}'"

    print(f"✓ Successfully tested duplicate file renaming:")
    for i, file in enumerate(files):
        print(f"  File {i+1}: original='{file.original_file_name}' → renamed='{file.file_name}'")
|
||||
|
||||
|
||||
def test_open_files_schema_descriptions(client: LettaSDKClient):
|
||||
"""Test that open_files tool schema contains correct descriptions from docstring"""
|
||||
|
||||
|
||||
@@ -282,21 +282,21 @@ def test_coerce_dict_args_with_default_arguments():
|
||||
|
||||
def test_valid_filename():
|
||||
filename = "valid_filename.txt"
|
||||
sanitized = sanitize_filename(filename)
|
||||
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
assert sanitized.startswith("valid_filename_")
|
||||
assert sanitized.endswith(".txt")
|
||||
|
||||
|
||||
def test_filename_with_special_characters():
|
||||
filename = "invalid:/<>?*ƒfilename.txt"
|
||||
sanitized = sanitize_filename(filename)
|
||||
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
assert sanitized.startswith("ƒfilename_")
|
||||
assert sanitized.endswith(".txt")
|
||||
|
||||
|
||||
def test_null_byte_in_filename():
|
||||
filename = "valid\0filename.txt"
|
||||
sanitized = sanitize_filename(filename)
|
||||
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
assert "\0" not in sanitized
|
||||
assert sanitized.startswith("validfilename_")
|
||||
assert sanitized.endswith(".txt")
|
||||
@@ -304,13 +304,13 @@ def test_null_byte_in_filename():
|
||||
|
||||
def test_path_traversal_characters():
|
||||
filename = "../../etc/passwd"
|
||||
sanitized = sanitize_filename(filename)
|
||||
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
assert sanitized.startswith("passwd_")
|
||||
assert len(sanitized) <= MAX_FILENAME_LENGTH
|
||||
|
||||
|
||||
def test_empty_filename():
|
||||
sanitized = sanitize_filename("")
|
||||
sanitized = sanitize_filename("", add_uuid_suffix=True)
|
||||
assert sanitized.startswith("_")
|
||||
|
||||
|
||||
@@ -326,15 +326,15 @@ def test_dotdot_as_filename():
|
||||
|
||||
def test_long_filename():
|
||||
filename = "a" * (MAX_FILENAME_LENGTH + 10) + ".txt"
|
||||
sanitized = sanitize_filename(filename)
|
||||
sanitized = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
assert len(sanitized) <= MAX_FILENAME_LENGTH
|
||||
assert sanitized.endswith(".txt")
|
||||
|
||||
|
||||
def test_unique_filenames():
|
||||
filename = "duplicate.txt"
|
||||
sanitized1 = sanitize_filename(filename)
|
||||
sanitized2 = sanitize_filename(filename)
|
||||
sanitized1 = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
sanitized2 = sanitize_filename(filename, add_uuid_suffix=True)
|
||||
assert sanitized1 != sanitized2
|
||||
assert sanitized1.startswith("duplicate_")
|
||||
assert sanitized2.startswith("duplicate_")
|
||||
@@ -342,6 +342,18 @@ def test_unique_filenames():
|
||||
assert sanitized2.endswith(".txt")
|
||||
|
||||
|
||||
def test_basic_sanitization_no_suffix():
|
||||
"""Test the new behavior - basic sanitization without UUID suffix"""
|
||||
filename = "test_file.txt"
|
||||
sanitized = sanitize_filename(filename)
|
||||
assert sanitized == "test_file.txt"
|
||||
|
||||
# Test with special characters
|
||||
filename_with_chars = "test:/<>?*file.txt"
|
||||
sanitized_chars = sanitize_filename(filename_with_chars)
|
||||
assert sanitized_chars == "file.txt"
|
||||
|
||||
|
||||
def test_formatter():
|
||||
|
||||
# Example system prompt that has no vars
|
||||
|
||||
Reference in New Issue
Block a user