import asyncio import os from datetime import datetime from typing import List, Optional from sqlalchemy import func, select, update from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload from letta.constants import MAX_FILENAME_LENGTH from letta.orm.errors import NoResultFound from letta.orm.file import FileContent as FileContentModel from letta.orm.file import FileMetadata as FileMetadataModel from letta.orm.sqlalchemy_base import AccessType from letta.otel.tracing import trace_method from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.source import Source as PydanticSource from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types class DuplicateFileError(Exception): """Raised when a duplicate file is encountered and error handling is specified""" def __init__(self, filename: str, source_name: str): self.filename = filename self.source_name = source_name super().__init__(f"File '{filename}' already exists in source '{source_name}'") class FileManager: """Manager class to handle business logic related to files.""" async def _invalidate_file_caches(self, file_id: str, actor: PydanticUser, original_filename: str = None, source_id: str = None): """Invalidate all caches related to a file.""" # TEMPORARILY DISABLED - caching is disabled # # invalidate file content cache (all variants) # await self.get_file_by_id.cache_invalidate(self, file_id, actor, include_content=True) # await self.get_file_by_id.cache_invalidate(self, file_id, actor, include_content=False) # # invalidate filename-based cache if we have the info # if original_filename and source_id: # await self.get_file_by_original_name_and_source.cache_invalidate(self, original_filename, source_id, actor) @enforce_types @trace_method async def create_file( self, file_metadata: PydanticFileMetadata, actor: PydanticUser, *, text: Optional[str] = None, ) -> PydanticFileMetadata: # short-circuit if it already exists existing = await self.get_file_by_id(file_metadata.id, actor=actor) if existing: return existing async with db_registry.async_session() as session: try: file_metadata.organization_id = actor.organization_id file_orm = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True)) await file_orm.create_async(session, actor=actor, no_commit=True) if text is not None: content_orm = FileContentModel(file_id=file_orm.id, text=text) await content_orm.create_async(session, actor=actor, no_commit=True) await session.commit() await session.refresh(file_orm) # invalidate cache for this new file await self._invalidate_file_caches(file_orm.id, actor, file_orm.original_file_name, file_orm.source_id) return await file_orm.to_pydantic_async() except IntegrityError: await session.rollback() return await self.get_file_by_id(file_metadata.id, actor=actor) # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @trace_method # @async_redis_cache( # key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}", # prefix="file_content", # ttl_s=3600, # model_class=PydanticFileMetadata, # ) async def get_file_by_id( self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False ) -> Optional[PydanticFileMetadata]: """Retrieve a file by its ID. If `include_content=True`, the FileContent relationship is eagerly loaded so `to_pydantic(include_content=True)` never triggers a lazy SELECT (avoids MissingGreenlet). """ async with db_registry.async_session() as session: try: if include_content: # explicit eager load query = ( select(FileMetadataModel).where(FileMetadataModel.id == file_id).options(selectinload(FileMetadataModel.content)) ) # apply org-scoping if actor provided if actor: query = FileMetadataModel.apply_access_predicate( query, actor, access=["read"], access_type=AccessType.ORGANIZATION, ) result = await session.execute(query) file_orm = result.scalar_one() else: # fast path (metadata only) file_orm = await FileMetadataModel.read_async( db_session=session, identifier=file_id, actor=actor, ) return await file_orm.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) except NoResultFound: return None @enforce_types @trace_method async def update_file_status( self, *, file_id: str, actor: PydanticUser, processing_status: Optional[FileProcessingStatus] = None, error_message: Optional[str] = None, total_chunks: Optional[int] = None, chunks_embedded: Optional[int] = None, enforce_state_transitions: bool = True, ) -> Optional[PydanticFileMetadata]: """ Update processing_status, error_message, total_chunks, and/or chunks_embedded on a FileMetadata row. Enforces state transition rules (when enforce_state_transitions=True): - PENDING -> PARSING -> EMBEDDING -> COMPLETED (normal flow) - Any non-terminal state -> ERROR - ERROR and COMPLETED are terminal (no transitions allowed) Args: file_id: ID of the file to update actor: User performing the update processing_status: New processing status to set error_message: Error message to set (if any) total_chunks: Total number of chunks in the file chunks_embedded: Number of chunks already embedded enforce_state_transitions: Whether to enforce state transition rules (default: True). Set to False to bypass validation for testing or special cases. Returns: Updated file metadata, or None if the update was blocked * 1st round-trip → UPDATE with optional state validation * 2nd round-trip → SELECT fresh row (same as read_async) if update succeeded """ if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None: raise ValueError("Nothing to update") values: dict[str, object] = {"updated_at": datetime.utcnow()} if processing_status is not None: values["processing_status"] = processing_status if error_message is not None: values["error_message"] = error_message if total_chunks is not None: values["total_chunks"] = total_chunks if chunks_embedded is not None: values["chunks_embedded"] = chunks_embedded # validate state transitions before making any database calls if enforce_state_transitions and processing_status == FileProcessingStatus.PENDING: # PENDING cannot be set after initial creation raise ValueError(f"Cannot transition to PENDING state for file {file_id} - PENDING is only valid as initial state") async with db_registry.async_session() as session: # build where conditions where_conditions = [ FileMetadataModel.id == file_id, FileMetadataModel.organization_id == actor.organization_id, ] # only add state transition validation if enforce_state_transitions is True if enforce_state_transitions: # prevent updates to terminal states (ERROR, COMPLETED) where_conditions.append( FileMetadataModel.processing_status.notin_([FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED]) ) if processing_status is not None: # enforce specific transitions based on target status if processing_status == FileProcessingStatus.PARSING: where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PENDING) elif processing_status == FileProcessingStatus.EMBEDDING: where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.PARSING) elif processing_status == FileProcessingStatus.COMPLETED: where_conditions.append(FileMetadataModel.processing_status == FileProcessingStatus.EMBEDDING) # ERROR can be set from any non-terminal state (already handled by terminal check above) # fast in-place update with state validation stmt = ( update(FileMetadataModel) .where(*where_conditions) .values(**values) .returning(FileMetadataModel.id) # return id if update succeeded ) result = await session.execute(stmt) updated_id = result.scalar() if not updated_id: # update was blocked await session.commit() if enforce_state_transitions: # update was blocked by state transition rules - raise error # fetch current state to provide informative error current_file = await FileMetadataModel.read_async( db_session=session, identifier=file_id, actor=actor, ) current_status = current_file.processing_status # build informative error message if processing_status is not None: if current_status in [FileProcessingStatus.ERROR, FileProcessingStatus.COMPLETED]: raise ValueError( f"Cannot update file {file_id} status from terminal state {current_status} to {processing_status}" ) else: raise ValueError(f"Invalid state transition for file {file_id}: {current_status} -> {processing_status}") else: raise ValueError(f"Cannot update file {file_id} in terminal state {current_status}") else: # validation was bypassed but update still failed (e.g., file doesn't exist) return None await session.commit() # invalidate cache for this file await self._invalidate_file_caches(file_id, actor) # reload via normal accessor so we return a fully-attached object file_orm = await FileMetadataModel.read_async( db_session=session, identifier=file_id, actor=actor, ) return await file_orm.to_pydantic_async() @enforce_types @trace_method async def upsert_file_content( self, *, file_id: str, text: str, actor: PydanticUser, ) -> PydanticFileMetadata: async with db_registry.async_session() as session: await FileMetadataModel.read_async(session, file_id, actor) dialect_name = session.bind.dialect.name if dialect_name == "postgresql": stmt = ( pg_insert(FileContentModel) .values(file_id=file_id, text=text) .on_conflict_do_update( index_elements=[FileContentModel.file_id], set_={"text": text}, ) ) await session.execute(stmt) else: # Emulate upsert for SQLite and others stmt = select(FileContentModel).where(FileContentModel.file_id == file_id) result = await session.execute(stmt) existing = result.scalar_one_or_none() if existing: await session.execute(update(FileContentModel).where(FileContentModel.file_id == file_id).values(text=text)) else: session.add(FileContentModel(file_id=file_id, text=text)) await session.commit() # invalidate cache for this file since content changed await self._invalidate_file_caches(file_id, actor) # Reload with content query = select(FileMetadataModel).options(selectinload(FileMetadataModel.content)).where(FileMetadataModel.id == file_id) result = await session.execute(query) return await result.scalar_one().to_pydantic_async(include_content=True) @enforce_types @trace_method async def list_files( self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, include_content: bool = False, strip_directory_prefix: bool = False, ) -> List[PydanticFileMetadata]: """List all files with optional pagination.""" async with db_registry.async_session() as session: options = [selectinload(FileMetadataModel.content)] if include_content else None files = await FileMetadataModel.list_async( db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id, query_options=options, ) return [ await file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) for file in files ] @enforce_types @trace_method async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: """Delete a file by its ID.""" async with db_registry.async_session() as session: file = await FileMetadataModel.read_async(db_session=session, identifier=file_id) # invalidate cache for this file before deletion await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id) await file.hard_delete_async(db_session=session, actor=actor) return await file.to_pydantic_async() @enforce_types @trace_method async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str: """ Generate a unique filename by adding a numeric suffix if duplicates exist. Always returns a unique filename - does not handle duplicate policies. Parameters: original_filename (str): The original filename as uploaded. source (PydanticSource): Source to check for duplicates within. organization_id (str): Organization ID to check for duplicates within. Returns: str: A unique filename with source.name prefix and numeric suffix if needed. """ base, ext = os.path.splitext(original_filename) # Reserve space for potential suffix: " (999)" = 6 characters max_base_length = MAX_FILENAME_LENGTH - len(ext) - 6 if len(base) > max_base_length: base = base[:max_base_length] original_filename = f"{base}{ext}" async with db_registry.async_session() as session: # Count existing files with the same original_file_name in this source query = select(func.count(FileMetadataModel.id)).where( FileMetadataModel.original_file_name == original_filename, FileMetadataModel.source_id == source.id, FileMetadataModel.organization_id == organization_id, FileMetadataModel.is_deleted == False, ) result = await session.execute(query) count = result.scalar() or 0 if count == 0: # No duplicates, return original filename with source.name return f"{source.name}/{original_filename}" else: # Add numeric suffix to make unique return f"{source.name}/{base}_({count}){ext}" @enforce_types @trace_method # @async_redis_cache( # key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}", # prefix="file_by_name", # ttl_s=3600, # model_class=PydanticFileMetadata, # ) async def get_file_by_original_name_and_source( self, original_filename: str, source_id: str, actor: PydanticUser ) -> Optional[PydanticFileMetadata]: """ Get a file by its original filename and source ID. Parameters: original_filename (str): The original filename to search for. source_id (str): The source ID to search within. actor (PydanticUser): The actor performing the request. Returns: Optional[PydanticFileMetadata]: The file metadata if found, None otherwise. """ async with db_registry.async_session() as session: query = ( select(FileMetadataModel) .where( FileMetadataModel.original_file_name == original_filename, FileMetadataModel.source_id == source_id, FileMetadataModel.organization_id == actor.organization_id, FileMetadataModel.is_deleted == False, ) .limit(1) ) result = await session.execute(query) file_orm = result.scalar_one_or_none() if file_orm: return await file_orm.to_pydantic_async() return None @enforce_types @trace_method async def get_organization_sources_metadata( self, actor: PydanticUser, include_detailed_per_source_metadata: bool = False ) -> OrganizationSourcesStats: """ Get aggregated metadata for all sources in an organization with optimized queries. Returns structured metadata including: - Total number of sources - Total number of files across all sources - Total size of all files - Per-source breakdown with file details (if include_detailed_per_source_metadata is True) """ async with db_registry.async_session() as session: # Import here to avoid circular imports from letta.orm.source import Source as SourceModel # Single optimized query to get all sources with their file aggregations query = ( select( SourceModel.id, SourceModel.name, func.count(FileMetadataModel.id).label("file_count"), func.coalesce(func.sum(FileMetadataModel.file_size), 0).label("total_size"), ) .outerjoin(FileMetadataModel, (FileMetadataModel.source_id == SourceModel.id) & (FileMetadataModel.is_deleted == False)) .where(SourceModel.organization_id == actor.organization_id) .where(SourceModel.is_deleted == False) .group_by(SourceModel.id, SourceModel.name) .order_by(SourceModel.name) ) result = await session.execute(query) source_aggregations = result.fetchall() # Build response metadata = OrganizationSourcesStats() for row in source_aggregations: source_id, source_name, file_count, total_size = row if include_detailed_per_source_metadata: # Get individual file details for this source files_query = ( select(FileMetadataModel.id, FileMetadataModel.file_name, FileMetadataModel.file_size) .where( FileMetadataModel.source_id == source_id, FileMetadataModel.organization_id == actor.organization_id, FileMetadataModel.is_deleted == False, ) .order_by(FileMetadataModel.file_name) ) files_result = await session.execute(files_query) files_rows = files_result.fetchall() # Build file stats files = [FileStats(file_id=file_row[0], file_name=file_row[1], file_size=file_row[2]) for file_row in files_rows] # Build source metadata source_metadata = SourceStats( source_id=source_id, source_name=source_name, file_count=file_count, total_size=total_size, files=files ) metadata.sources.append(source_metadata) metadata.total_files += file_count metadata.total_size += total_size metadata.total_sources = len(source_aggregations) return metadata @enforce_types @trace_method async def get_files_by_ids_async( self, file_ids: List[str], actor: PydanticUser, *, include_content: bool = False ) -> List[PydanticFileMetadata]: """ Get multiple files by their IDs in a single query. Args: file_ids: List of file IDs to retrieve actor: User performing the action include_content: Whether to include file content in the response Returns: List[PydanticFileMetadata]: List of files (may be fewer than requested if some don't exist) """ if not file_ids: return [] async with db_registry.async_session() as session: query = select(FileMetadataModel).where( FileMetadataModel.id.in_(file_ids), FileMetadataModel.organization_id == actor.organization_id, FileMetadataModel.is_deleted == False, ) # Eagerly load content if requested if include_content: query = query.options(selectinload(FileMetadataModel.content)) result = await session.execute(query) files_orm = result.scalars().all() return await asyncio.gather(*[file.to_pydantic_async(include_content=include_content) for file in files_orm]) @enforce_types @trace_method async def get_files_for_agents_async( self, agent_ids: List[str], actor: PydanticUser, *, include_content: bool = False ) -> List[PydanticFileMetadata]: """ Get all files associated with the given agents via file-agent relationships. Args: agent_ids: List of agent IDs to find files for actor: User performing the action include_content: Whether to include file content in the response Returns: List[PydanticFileMetadata]: List of unique files associated with these agents """ if not agent_ids: return [] async with db_registry.async_session() as session: # We need to import FileAgent here to avoid circular imports from letta.orm.file_agent import FileAgent as FileAgentModel # Join through file-agent relationships query = ( select(FileMetadataModel) .join(FileAgentModel, FileMetadataModel.id == FileAgentModel.file_id) .where( FileAgentModel.agent_id.in_(agent_ids), FileMetadataModel.organization_id == actor.organization_id, FileMetadataModel.is_deleted == False, FileAgentModel.is_deleted == False, ) .distinct() # Ensure we don't get duplicate files ) # Eagerly load content if requested if include_content: query = query.options(selectinload(FileMetadataModel.content)) result = await session.execute(query) files_orm = result.scalars().all() return await asyncio.gather(*[file.to_pydantic_async(include_content=include_content) for file in files_orm])