Files
letta-server/letta/services/file_manager.py
Kian Jones f5c4ab50f4 chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import

* remove all ignores, add FastAPI rules and Ruff rules

* add ty and precommit

* ruff stuff

* ty check fixes

* ty check fixes pt 2

* error on invalid
2026-02-24 10:55:11 -08:00

731 lines
32 KiB
Python

import os
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from letta.constants import MAX_FILENAME_LENGTH
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.schemas.enums import FileProcessingStatus, PrimitiveType
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import settings
from letta.utils import bounded_gather, enforce_types
from letta.validators import raise_on_invalid_id
# Module-level logger used by FileManager for status-check warnings and embedding progress.
logger = get_logger(__name__)
class DuplicateFileError(Exception):
    """Signal that a file with the given name already exists in a source.

    Both names are kept on the instance so handlers can build their own messages.
    """

    def __init__(self, filename: str, source_name: str):
        message = f"File '{filename}' already exists in source '{source_name}'"
        super().__init__(message)
        self.filename, self.source_name = filename, source_name
class FileManager:
    """Manager class to handle business logic related to files."""

    async def _invalidate_file_caches(
        self, file_id: str, actor: PydanticUser, original_filename: str | None = None, source_id: str | None = None
    ):
        """Invalidate all caches related to a file.

        Currently a no-op: the cache decorators on `get_file_by_id` and
        `get_file_by_original_name_and_source` are commented out, so there is nothing to drop.
        Call sites are kept so re-enabling caching only requires restoring the body below.

        Args:
            file_id: ID of the file whose cached entries should be dropped.
            actor: User whose organization is part of the cache keys.
            original_filename: Original upload name; needed to drop the filename-keyed entry.
            source_id: Source of the file; needed to drop the filename-keyed entry.
        """
        # TEMPORARILY DISABLED - caching is disabled
        # # invalidate file content cache (all variants)
        # await self.get_file_by_id.cache_invalidate(self, file_id, actor, include_content=True)
        # await self.get_file_by_id.cache_invalidate(self, file_id, actor, include_content=False)
        # # invalidate filename-based cache if we have the info
        # if original_filename and source_id:
        #     await self.get_file_by_original_name_and_source.cache_invalidate(self, original_filename, source_id, actor)

    @enforce_types
    @trace_method
    async def create_file(
        self,
        file_metadata: PydanticFileMetadata,
        actor: PydanticUser,
        *,
        text: Optional[str] = None,
    ) -> PydanticFileMetadata:
        """Persist a file metadata row, and optionally its text content, in a single transaction.

        Idempotent on `file_metadata.id`: if a file with that id already exists, or a concurrent
        insert of the same id wins the race, the existing row is returned instead.

        Args:
            file_metadata: Metadata to persist. Its `organization_id` is overwritten in place
                with the actor's organization.
            actor: User performing the creation; determines the owning organization.
            text: Optional raw text stored as the file's FileContent row.

        Returns:
            The newly created (or already existing) file metadata.

        Raises:
            IntegrityError: If the insert violates a constraint and no file with this id exists
                afterwards, i.e. the failure was not a duplicate-id race.
        """
        # short-circuit if it already exists
        try:
            existing = await self.get_file_by_id(file_metadata.id, actor=actor)
        except NoResultFound:
            existing = None
        if existing:
            return existing

        async with db_registry.async_session() as session:
            try:
                file_metadata.organization_id = actor.organization_id
                file_orm = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
                await file_orm.create_async(session, actor=actor, no_commit=True)
                if text is not None:
                    content_orm = FileContentModel(file_id=file_orm.id, text=text)
                    await content_orm.create_async(session, actor=actor, no_commit=True)
                # metadata and content are committed together, so a failure rolls back both
                await session.commit()
                await session.refresh(file_orm)
                # invalidate cache for this new file
                await self._invalidate_file_caches(file_orm.id, actor, file_orm.original_file_name, file_orm.source_id)
                return file_orm.to_pydantic()
            except IntegrityError:
                await session.rollback()
                # most likely a concurrent create of the same id won the race - return that row
                existing = await self.get_file_by_id(file_metadata.id, actor=actor)
                if existing is None:
                    # the violation was not a duplicate id (e.g. another constraint); surface it
                    # rather than silently returning None from a non-Optional method
                    raise
                return existing

    @enforce_types
    @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
    @trace_method
    # @async_redis_cache(
    #     key_func=lambda self, file_id, actor, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id}:{include_content}:{strip_directory_prefix}",
    #     prefix="file_content",
    #     ttl_s=3600,
    #     model_class=PydanticFileMetadata,
    # )
    async def get_file_by_id(
        self, file_id: str, actor: PydanticUser, *, include_content: bool = False, strip_directory_prefix: bool = False
    ) -> Optional[PydanticFileMetadata]:
        """Retrieve a file by its ID.

        If `include_content=True`, the FileContent relationship is eagerly
        loaded so `to_pydantic(include_content=True)` never triggers a
        lazy SELECT (avoids MissingGreenlet).

        Args:
            file_id: ID of the file to fetch.
            actor: User performing the lookup; used for organization-level access checks.
            include_content: Also load and return the file's text content.
            strip_directory_prefix: Forwarded to `to_pydantic_async` to strip the directory
                prefix from the returned file name.

        Returns:
            The file metadata, or None if no matching file is visible to the actor.
        """
        async with db_registry.async_session() as session:
            if include_content:
                # explicit eager load
                query = select(FileMetadataModel).where(FileMetadataModel.id == file_id).options(selectinload(FileMetadataModel.content))
                # apply org-scoping if actor provided
                # NOTE(review): no explicit is_deleted filter on this path, unlike the other queries
                # in this class - confirm apply_access_predicate excludes soft-deleted rows
                if actor:
                    query = FileMetadataModel.apply_access_predicate(
                        query,
                        actor,
                        access=["read"],
                        access_type=AccessType.ORGANIZATION,
                    )
                result = await session.execute(query)
                file_orm = result.scalar_one_or_none()
            else:
                # fast path (metadata only)
                try:
                    file_orm = await FileMetadataModel.read_async(
                        db_session=session,
                        identifier=file_id,
                        actor=actor,
                    )
                except NoResultFound:
                    return None

            if file_orm is None:
                return None

            return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)

    @enforce_types
    @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
    @trace_method
    async def update_file_status(
        self,
        *,
        file_id: str,
        actor: PydanticUser,
        processing_status: Optional[FileProcessingStatus] = None,
        error_message: Optional[str] = None,
        total_chunks: Optional[int] = None,
        chunks_embedded: Optional[int] = None,
        enforce_state_transitions: bool = True,
    ) -> Optional[PydanticFileMetadata]:
        """
        Update processing_status, error_message, total_chunks, and/or chunks_embedded on a FileMetadata row.

        Enforces state transition rules (when enforce_state_transitions=True):
        - PENDING -> PARSING -> EMBEDDING -> COMPLETED (normal flow)
        - Any non-terminal state -> ERROR
        - Same-state transitions are allowed (e.g., EMBEDDING -> EMBEDDING)
        - ERROR and COMPLETED are terminal (no status transitions allowed, metadata updates blocked)

        Args:
            file_id: ID of the file to update
            actor: User performing the update
            processing_status: New processing status to set
            error_message: Error message to set (if any)
            total_chunks: Total number of chunks in the file
            chunks_embedded: Number of chunks already embedded
            enforce_state_transitions: Whether to enforce state transition rules (default: True).
                Set to False to bypass validation for testing or special cases.

        Returns:
            Updated file metadata, or None if the update was blocked (only possible when
            enforce_state_transitions=False, e.g. the file does not exist)

        Raises:
            ValueError: If nothing is to be updated, if ERROR is requested without an error
                message, if PENDING is requested, or if a transition rule blocked the update.

        * 1st round-trip → UPDATE with optional state validation
        * 2nd round-trip → SELECT fresh row (same as read_async) if update succeeded
        """
        if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
            raise ValueError("Nothing to update")

        # validate that ERROR status must have an error message
        if processing_status == FileProcessingStatus.ERROR and not error_message:
            raise ValueError("Error message is required when setting processing status to ERROR")

        # NOTE(review): naive UTC timestamp; datetime.utcnow() is deprecated since Python 3.12,
        # but switching to an aware value depends on the column type - confirm before changing
        values: dict[str, object] = {"updated_at": datetime.utcnow()}
        if processing_status is not None:
            values["processing_status"] = processing_status
        if error_message is not None:
            values["error_message"] = error_message
        if total_chunks is not None:
            values["total_chunks"] = total_chunks
        if chunks_embedded is not None:
            values["chunks_embedded"] = chunks_embedded

        # validate state transitions before making any database calls
        if enforce_state_transitions and processing_status == FileProcessingStatus.PENDING:
            # PENDING cannot be set after initial creation
            raise ValueError(f"Cannot transition to PENDING state for file {file_id} - PENDING is only valid as initial state")

        async with db_registry.async_session() as session:
            # build where conditions
            where_conditions = [
                FileMetadataModel.id == file_id,
                FileMetadataModel.organization_id == actor.organization_id,
            ]

            # only add state transition validation if enforce_state_transitions is True
            # (the allowed source states are encoded in the WHERE clause so the check and the
            # write happen atomically in a single UPDATE)
            if enforce_state_transitions and processing_status is not None:
                # enforce specific transitions based on target status
                if processing_status == FileProcessingStatus.PARSING:
                    where_conditions.append(
                        FileMetadataModel.processing_status.in_([FileProcessingStatus.PENDING, FileProcessingStatus.PARSING])
                    )
                elif processing_status == FileProcessingStatus.EMBEDDING:
                    where_conditions.append(
                        FileMetadataModel.processing_status.in_([FileProcessingStatus.PARSING, FileProcessingStatus.EMBEDDING])
                    )
                elif processing_status == FileProcessingStatus.COMPLETED:
                    where_conditions.append(
                        FileMetadataModel.processing_status.in_([FileProcessingStatus.EMBEDDING, FileProcessingStatus.COMPLETED])
                    )
                elif processing_status == FileProcessingStatus.ERROR:
                    # ERROR can be set from any non-terminal state
                    where_conditions.append(
                        FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED])
                    )
            elif enforce_state_transitions and processing_status is None:
                # If only updating metadata fields (not status), prevent updates to terminal states
                where_conditions.append(
                    FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED])
                )

            # fast in-place update with state validation
            stmt = (
                update(FileMetadataModel)
                .where(*where_conditions)
                .values(**values)
                .returning(FileMetadataModel.id)  # return id if update succeeded
            )
            result = await session.execute(stmt)
            updated_id = result.scalar()

            if not updated_id:
                # update was blocked
                await session.commit()

                if enforce_state_transitions:
                    # update was blocked by state transition rules - raise error
                    # fetch current state to provide informative error
                    current_file = await FileMetadataModel.read_async(
                        db_session=session,
                        identifier=file_id,
                        actor=actor,
                    )
                    current_status = current_file.processing_status

                    # build informative error message
                    if processing_status is not None:
                        if current_status in [FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED]:
                            raise ValueError(
                                f"Cannot update file {file_id} status from terminal state {current_status} to {processing_status}"
                            )
                        else:
                            raise ValueError(f"Invalid state transition for file {file_id}: {current_status} -> {processing_status}")
                    else:
                        raise ValueError(f"Cannot update file {file_id} in terminal state {current_status}")
                else:
                    # validation was bypassed but update still failed (e.g., file doesn't exist)
                    return None

            await session.commit()

            # invalidate cache for this file
            await self._invalidate_file_caches(file_id, actor)

            # reload via normal accessor so we return a fully-attached object
            file_orm = await FileMetadataModel.read_async(
                db_session=session,
                identifier=file_id,
                actor=actor,
            )
            return file_orm.to_pydantic()

    @enforce_types
    @trace_method
    async def check_and_update_file_status(
        self,
        file_metadata: PydanticFileMetadata,
        actor: PydanticUser,
    ) -> PydanticFileMetadata:
        """
        Check and update file status for timeout and embedding completion.

        This method consolidates logic for:
        1. Checking if a file has timed out during processing
        2. Checking Pinecone embedding status and updating counts

        Blocked state transitions (ValueError from `update_file_status`) are logged and
        swallowed; in that case the input metadata is returned as-is.

        Args:
            file_metadata: The file metadata to check
            actor: User performing the check

        Returns:
            Updated file metadata with current status
        """
        # check for timeout if status is not terminal
        if not file_metadata.processing_status.is_terminal_state() and file_metadata.created_at:
            file_created_at = file_metadata.created_at
            # PostgreSQL yields timezone-aware datetimes while SQLite yields naive ones, which are
            # interpreted as UTC. Normalizing to aware UTC means the comparison below never mixes
            # aware and naive values (that would raise TypeError) regardless of configuration.
            if file_created_at.tzinfo is None:
                file_created_at = file_created_at.replace(tzinfo=timezone.utc)
            timeout_threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.file_processing_timeout_minutes)

            if file_created_at < timeout_threshold:
                # move file to error status with timeout message
                timeout_message = settings.file_processing_timeout_error_message.format(settings.file_processing_timeout_minutes)
                try:
                    file_metadata = await self.update_file_status(
                        file_id=file_metadata.id,
                        actor=actor,
                        processing_status=FileProcessingStatus.ERROR,
                        error_message=timeout_message,
                    )
                except ValueError as e:
                    # state transition was blocked - log it but don't fail
                    logger.warning(f"Could not update file to timeout error state: {str(e)}")
                    # continue with existing file_metadata

        # check pinecone embedding status
        if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
            ids = await list_pinecone_index_for_files(file_id=file_metadata.id, actor=actor)
            logger.info(
                f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_metadata.id} ({file_metadata.file_name}) in organization {actor.organization_id}"
            )

            # only write when the embedded count moved, or when every chunk is embedded (-> COMPLETED)
            if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
                if len(ids) != file_metadata.total_chunks:
                    file_status = file_metadata.processing_status
                else:
                    file_status = FileProcessingStatus.COMPLETED
                try:
                    file_metadata = await self.update_file_status(
                        file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
                    )
                except ValueError as e:
                    # state transition was blocked - this is a race condition
                    # log it but don't fail since we're just checking status
                    logger.warning(f"Race condition detected in check_and_update_file_status: {str(e)}")
                    # return the current file state without updating

        return file_metadata

    @enforce_types
    @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
    @trace_method
    async def upsert_file_content(
        self,
        *,
        file_id: str,
        text: str,
        actor: PydanticUser,
    ) -> PydanticFileMetadata:
        """Create or replace the text content of a file.

        Uses a native ``INSERT ... ON CONFLICT DO UPDATE`` on PostgreSQL and a
        select-then-update/insert emulation on other dialects.

        Args:
            file_id: ID of the file whose content is written.
            text: The new full text content.
            actor: User performing the write; must be able to read the file.

        Returns:
            The file metadata with content included.

        Raises:
            NoResultFound: If the file does not exist or is not visible to the actor.
        """
        async with db_registry.async_session() as session:
            # existence / access check before touching the content table
            await FileMetadataModel.read_async(session, file_id, actor)

            dialect_name = session.bind.dialect.name

            if dialect_name == "postgresql":
                stmt = (
                    pg_insert(FileContentModel)
                    .values(file_id=file_id, text=text)
                    .on_conflict_do_update(
                        index_elements=[FileContentModel.file_id],
                        set_={"text": text},
                    )
                )
                await session.execute(stmt)
            else:
                # Emulate upsert for SQLite and others
                stmt = select(FileContentModel).where(FileContentModel.file_id == file_id)
                result = await session.execute(stmt)
                existing = result.scalar_one_or_none()

                if existing:
                    await session.execute(update(FileContentModel).where(FileContentModel.file_id == file_id).values(text=text))
                else:
                    session.add(FileContentModel(file_id=file_id, text=text))

            await session.commit()

            # invalidate cache for this file since content changed
            await self._invalidate_file_caches(file_id, actor)

            # Reload with content
            query = select(FileMetadataModel).options(selectinload(FileMetadataModel.content)).where(FileMetadataModel.id == file_id)
            result = await session.execute(query)
            return await result.scalar_one().to_pydantic_async(include_content=True)

    @enforce_types
    @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
    @trace_method
    async def list_files(
        self,
        source_id: str,
        actor: PydanticUser,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 1000,
        ascending: Optional[bool] = True,
        include_content: bool = False,
        strip_directory_prefix: bool = False,
        check_status_updates: bool = False,
    ) -> List[PydanticFileMetadata]:
        """List all files with optional pagination and status checking.

        Args:
            source_id: Source to list files from
            actor: User performing the request
            before: Pagination cursor, forwarded to `list_async` as `before`
            after: Pagination cursor, forwarded to `list_async` as `after`
            limit: Maximum number of files to return
            ascending: Sort by ascending or descending order
            include_content: Whether to include file content
            strip_directory_prefix: Whether to strip directory prefix from filenames
            check_status_updates: Whether to check and update status for timeout and embedding completion

        Returns:
            List of file metadata
        """
        async with db_registry.async_session() as session:
            options = [selectinload(FileMetadataModel.content)] if include_content else None

            files = await FileMetadataModel.list_async(
                db_session=session,
                before=before,
                after=after,
                limit=limit,
                ascending=ascending,
                organization_id=actor.organization_id,
                source_id=source_id,
                query_options=options,
            )

            # convert all files to pydantic models
            if include_content:
                file_metadatas = await bounded_gather(
                    [
                        file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
                        for file in files
                    ]
                )
            else:
                file_metadatas = [file.to_pydantic(strip_directory_prefix=strip_directory_prefix) for file in files]

        # if status checking is enabled, check all files sequentially to avoid db pool exhaustion.
        # Each status check may update the file in the database, so concurrent checks with many
        # files can create too many simultaneous database connections. This runs after the listing
        # session is closed so that session's connection is not held while the checks open their own.
        if check_status_updates:
            updated_file_metadatas = []
            for file_metadata in file_metadatas:
                updated_metadata = await self.check_and_update_file_status(file_metadata, actor)
                updated_file_metadatas.append(updated_metadata)
            file_metadatas = updated_file_metadatas

        return file_metadatas

    @enforce_types
    @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
    @trace_method
    async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
        """Delete a file by its ID.

        Args:
            file_id: ID of the file to hard-delete.
            actor: User performing the deletion.

        Returns:
            The metadata of the deleted file, captured before deletion.
        """
        async with db_registry.async_session() as session:
            file_orm = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)

            # invalidate cache for this file before deletion
            await self._invalidate_file_caches(file_id, actor, file_orm.original_file_name, file_orm.source_id)

            # snapshot before deleting so we never read attributes off a deleted (possibly expired) instance
            deleted_file = file_orm.to_pydantic()
            await file_orm.hard_delete_async(db_session=session, actor=actor)
            return deleted_file

    @enforce_types
    @trace_method
    async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str:
        """
        Generate a unique filename by adding a numeric suffix if duplicates exist.
        Always returns a unique filename - does not handle duplicate policies.

        The suffix starts at the number of live files in the source sharing the same original
        filename. If that candidate name is already taken (e.g. after deletions lowered the count,
        or an upload's own name already ends in "_(N)"), the suffix is incremented until a name
        not used by any live file in the source is found.

        Parameters:
            original_filename (str): The original filename as uploaded.
            source (PydanticSource): Source to check for duplicates within.
            organization_id (str): Organization ID to check for duplicates within.

        Returns:
            str: A unique filename with source.name prefix and numeric suffix if needed.
        """
        base, ext = os.path.splitext(original_filename)

        # Reserve space for a potential suffix: "_(999)" = 6 characters. Clamp at zero so an
        # extension longer than the limit cannot produce a negative slice bound.
        max_base_length = max(MAX_FILENAME_LENGTH - len(ext) - 6, 0)
        if len(base) > max_base_length:
            base = base[:max_base_length]
            original_filename = f"{base}{ext}"

        def build_name(suffix: int) -> str:
            # suffix 0 means "no suffix": keep the (possibly truncated) original name
            if suffix == 0:
                return f"{source.name}/{original_filename}"
            return f"{source.name}/{base}_({suffix}){ext}"

        async with db_registry.async_session() as session:
            # Count existing files with the same original_file_name in this source
            count_query = select(func.count(FileMetadataModel.id)).where(
                FileMetadataModel.original_file_name == original_filename,
                FileMetadataModel.source_id == source.id,
                FileMetadataModel.organization_id == organization_id,
                FileMetadataModel.is_deleted == False,
            )
            suffix = (await session.execute(count_query)).scalar() or 0

            # The count alone does not guarantee uniqueness, so confirm no live file in this
            # source already uses the candidate name; bump the suffix until one is free.
            while True:
                candidate = build_name(suffix)
                taken_query = (
                    select(FileMetadataModel.id)
                    .where(
                        FileMetadataModel.file_name == candidate,
                        FileMetadataModel.source_id == source.id,
                        FileMetadataModel.organization_id == organization_id,
                        FileMetadataModel.is_deleted == False,
                    )
                    .limit(1)
                )
                if (await session.execute(taken_query)).first() is None:
                    return candidate
                suffix += 1

    @enforce_types
    @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
    @trace_method
    # @async_redis_cache(
    #     key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}",
    #     prefix="file_by_name",
    #     ttl_s=3600,
    #     model_class=PydanticFileMetadata,
    # )
    async def get_file_by_original_name_and_source(
        self, original_filename: str, source_id: str, actor: PydanticUser
    ) -> Optional[PydanticFileMetadata]:
        """
        Get a file by its original filename and source ID.

        If several live files share the name, one of them is returned (the query uses LIMIT 1
        without an ORDER BY, so which one is not specified).

        Parameters:
            original_filename (str): The original filename to search for.
            source_id (str): The source ID to search within.
            actor (PydanticUser): The actor performing the request.

        Returns:
            Optional[PydanticFileMetadata]: The file metadata if found, None otherwise.
        """
        async with db_registry.async_session() as session:
            query = (
                select(FileMetadataModel)
                .where(
                    FileMetadataModel.original_file_name == original_filename,
                    FileMetadataModel.source_id == source_id,
                    FileMetadataModel.organization_id == actor.organization_id,
                    FileMetadataModel.is_deleted == False,
                )
                .limit(1)
            )

            result = await session.execute(query)
            file_orm = result.scalar_one_or_none()

            if file_orm:
                return file_orm.to_pydantic()
            return None

    @enforce_types
    @trace_method
    async def get_organization_sources_metadata(
        self, actor: PydanticUser, include_detailed_per_source_metadata: bool = False
    ) -> OrganizationSourcesStats:
        """
        Get aggregated metadata for all sources in an organization with optimized queries.

        Returns structured metadata including:
        - Total number of sources
        - Total number of files across all sources
        - Total size of all files
        - Per-source breakdown with file details (if include_detailed_per_source_metadata is True)

        At most two queries are issued: one aggregate per source, plus (only when details are
        requested) one query for every file in the organization's live sources.
        """
        async with db_registry.async_session() as session:
            # Import here to avoid circular imports
            from letta.orm.source import Source as SourceModel

            # Single optimized query to get all sources with their file aggregations
            query = (
                select(
                    SourceModel.id,
                    SourceModel.name,
                    func.count(FileMetadataModel.id).label("file_count"),
                    func.coalesce(func.sum(FileMetadataModel.file_size), 0).label("total_size"),
                )
                .outerjoin(FileMetadataModel, (FileMetadataModel.source_id == SourceModel.id) & (FileMetadataModel.is_deleted == False))
                .where(SourceModel.organization_id == actor.organization_id)
                .where(SourceModel.is_deleted == False)
                .group_by(SourceModel.id, SourceModel.name)
                .order_by(SourceModel.name)
            )

            result = await session.execute(query)
            source_aggregations = result.fetchall()

            # Fetch per-file details for all sources at once instead of one query per source (N+1).
            # Ordering by (source_id, file_name) keeps each source's files sorted by name.
            files_by_source: dict[str, list[FileStats]] = {}
            if include_detailed_per_source_metadata and source_aggregations:
                files_query = (
                    select(
                        FileMetadataModel.source_id,
                        FileMetadataModel.id,
                        FileMetadataModel.file_name,
                        FileMetadataModel.file_size,
                    )
                    .join(SourceModel, SourceModel.id == FileMetadataModel.source_id)
                    .where(
                        FileMetadataModel.organization_id == actor.organization_id,
                        FileMetadataModel.is_deleted == False,
                        SourceModel.organization_id == actor.organization_id,
                        SourceModel.is_deleted == False,
                    )
                    .order_by(FileMetadataModel.source_id, FileMetadataModel.file_name)
                )
                files_result = await session.execute(files_query)
                for file_source_id, file_id, file_name, file_size in files_result.fetchall():
                    files_by_source.setdefault(file_source_id, []).append(
                        FileStats(file_id=file_id, file_name=file_name, file_size=file_size)
                    )

            # Build response
            metadata = OrganizationSourcesStats()

            for source_id, source_name, file_count, total_size in source_aggregations:
                if include_detailed_per_source_metadata:
                    # Build source metadata
                    source_metadata = SourceStats(
                        source_id=source_id,
                        source_name=source_name,
                        file_count=file_count,
                        total_size=total_size,
                        files=files_by_source.get(source_id, []),
                    )
                    metadata.sources.append(source_metadata)

                metadata.total_files += file_count
                metadata.total_size += total_size

            metadata.total_sources = len(source_aggregations)
            return metadata

    @enforce_types
    @trace_method
    async def get_files_by_ids_async(
        self, file_ids: List[str], actor: PydanticUser, *, include_content: bool = False
    ) -> List[PydanticFileMetadata]:
        """
        Get multiple files by their IDs in a single query.

        Args:
            file_ids: List of file IDs to retrieve
            actor: User performing the action
            include_content: Whether to include file content in the response

        Returns:
            List[PydanticFileMetadata]: List of files (may be fewer than requested if some don't exist)
        """
        if not file_ids:
            return []

        async with db_registry.async_session() as session:
            query = select(FileMetadataModel).where(
                FileMetadataModel.id.in_(file_ids),
                FileMetadataModel.organization_id == actor.organization_id,
                FileMetadataModel.is_deleted == False,
            )

            # Eagerly load content if requested
            if include_content:
                query = query.options(selectinload(FileMetadataModel.content))

            result = await session.execute(query)
            files_orm = result.scalars().all()

            if include_content:
                return await bounded_gather([file.to_pydantic_async(include_content=include_content) for file in files_orm])
            else:
                return [file.to_pydantic() for file in files_orm]

    @enforce_types
    @trace_method
    async def get_files_for_agents_async(
        self, agent_ids: List[str], actor: PydanticUser, *, include_content: bool = False
    ) -> List[PydanticFileMetadata]:
        """
        Get all files associated with the given agents via file-agent relationships.

        Args:
            agent_ids: List of agent IDs to find files for
            actor: User performing the action
            include_content: Whether to include file content in the response

        Returns:
            List[PydanticFileMetadata]: List of unique files associated with these agents
        """
        if not agent_ids:
            return []

        async with db_registry.async_session() as session:
            # We need to import FileAgent here to avoid circular imports
            from letta.orm.files_agents import FileAgent as FileAgentModel

            # Join through file-agent relationships
            query = (
                select(FileMetadataModel)
                .join(FileAgentModel, FileMetadataModel.id == FileAgentModel.file_id)
                .where(
                    FileAgentModel.agent_id.in_(agent_ids),
                    FileMetadataModel.organization_id == actor.organization_id,
                    FileMetadataModel.is_deleted == False,
                    FileAgentModel.is_deleted == False,
                )
                .distinct()  # Ensure we don't get duplicate files
            )

            # Eagerly load content if requested
            if include_content:
                query = query.options(selectinload(FileMetadataModel.content))

            result = await session.execute(query)
            files_orm = result.scalars().all()

            if include_content:
                return await bounded_gather([file.to_pydantic_async(include_content=include_content) for file in files_orm])
            else:
                return [file.to_pydantic() for file in files_orm]