diff --git a/letta/helpers/decorators.py b/letta/helpers/decorators.py index bfcb0665..77744ea1 100644 --- a/letta/helpers/decorators.py +++ b/letta/helpers/decorators.py @@ -152,7 +152,7 @@ def async_redis_cache( def get_cache_key(*args, **kwargs): return f"{prefix}:{key_func(*args, **kwargs)}" - # async_wrapper.cache_invalidate = invalidate + async_wrapper.cache_invalidate = invalidate async_wrapper.cache_key_func = get_cache_key async_wrapper.cache_stats = stats return async_wrapper diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index 530fa3e1..6aa86b16 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -8,6 +8,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload from letta.constants import MAX_FILENAME_LENGTH +from letta.helpers.decorators import async_redis_cache from letta.orm.errors import NoResultFound from letta.orm.file import FileContent as FileContentModel from letta.orm.file import FileMetadata as FileMetadataModel @@ -34,6 +35,16 @@ class DuplicateFileError(Exception): class FileManager: """Manager class to handle business logic related to files.""" + async def _invalidate_file_caches(self, file_id: str, actor: PydanticUser, original_filename: str = None, source_id: str = None): + """Invalidate all caches related to a file.""" + # invalidate file content cache for both include_content values (keys cached with strip_directory_prefix=True are not covered) + await self.get_file_by_id.cache_invalidate(self, file_id, actor, include_content=True) + await self.get_file_by_id.cache_invalidate(self, file_id, actor, include_content=False) + + # invalidate filename-based cache if we have the info + if original_filename and source_id: + await self.get_file_by_original_name_and_source.cache_invalidate(self, original_filename, source_id, actor) + @enforce_types @trace_method async def create_file( @@ -61,6 +72,10 @@ class FileManager: await session.commit() await session.refresh(file_orm) + + # invalidate cache for this new file + await 
self._invalidate_file_caches(file_orm.id, actor, file_orm.original_file_name, file_orm.source_id) + return await file_orm.to_pydantic_async() except IntegrityError: @@ -70,6 +85,12 @@ class FileManager: # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @trace_method + @async_redis_cache( + key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}", + prefix="file_content", + ttl_s=3600, + model_class=PydanticFileMetadata, + ) async def get_file_by_id( self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False ) -> Optional[PydanticFileMetadata]: @@ -155,6 +176,9 @@ class FileManager: await session.execute(stmt) await session.commit() + # invalidate cache for this file + await self._invalidate_file_caches(file_id, actor) + # Reload via normal accessor so we return a fully-attached object file_orm = await FileMetadataModel.read_async( db_session=session, @@ -200,6 +224,9 @@ class FileManager: await session.commit() + # invalidate cache for this file since content changed + await self._invalidate_file_caches(file_id, actor) + # Reload with content query = select(FileMetadataModel).options(selectinload(FileMetadataModel.content)).where(FileMetadataModel.id == file_id) result = await session.execute(query) @@ -239,6 +266,10 @@ class FileManager: """Delete a file by its ID.""" async with db_registry.async_session() as session: file = await FileMetadataModel.read_async(db_session=session, identifier=file_id) + + # invalidate cache for this file before deletion + await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id) + await file.hard_delete_async(db_session=session, actor=actor) return await file.to_pydantic_async() @@ -285,6 +316,12 @@ class FileManager: 
@enforce_types @trace_method + @async_redis_cache( + key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}", + prefix="file_by_name", + ttl_s=3600, + model_class=PydanticFileMetadata, + ) async def get_file_by_original_name_and_source( self, original_filename: str, source_id: str, actor: PydanticUser ) -> Optional[PydanticFileMetadata]: