diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 1276684c..c012c54a 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -42,7 +42,9 @@ def handle_db_timeout(func): logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e) except QueryCanceledError as e: - logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") + logger.error( + f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}" + ) raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e) return wrapper @@ -56,7 +58,9 @@ def handle_db_timeout(func): logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e) except QueryCanceledError as e: - logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}") + logger.error( + f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}" + ) raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e) return async_wrapper @@ -207,6 +211,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): """ Constructs the query for listing records. """ + # Security check: if the model has organization_id column, actor should be provided + if actor is None and hasattr(cls, "organization_id"): + logger.warning(f"SECURITY: Listing org-scoped model {cls.__name__} without actor. This bypasses organization filtering.") + query = select(cls) if join_model and join_conditions: @@ -446,6 +454,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): ): logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}") + # Security check: if the model has organization_id column, actor should be provided + # to ensure proper org-scoping. Log a warning if actor is None. + if actor is None and hasattr(cls, "organization_id"): + logger.warning( + f"SECURITY: Reading org-scoped model {cls.__name__} without actor. " + f"IDs: {identifiers}. This bypasses organization filtering." + ) + # Start the query query = select(cls) # Collect query conditions for better error reporting @@ -681,6 +697,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): **kwargs, ): logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}") + + # Security check: if the model has organization_id column, actor should be provided + if actor is None and hasattr(cls, "organization_id"): + logger.warning( + f"SECURITY: Calculating size for org-scoped model {cls.__name__} without actor. This bypasses organization filtering." + ) query = select(func.count(1)).select_from(cls) if actor: diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index e0abe150..d480b963 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -721,7 +721,7 @@ async def attach_source( await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor) if agent_state.enable_sleeptime: - source = await server.source_manager.get_source_by_id(source_id=source_id) + source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async") return agent_state @@ -748,7 +748,7 @@ async def attach_folder_to_agent( await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor) if agent_state.enable_sleeptime: - source = await server.source_manager.get_source_by_id(source_id=folder_id) + source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor) safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async") if is_1_0_sdk_version(headers): @@ -779,7 +779,7 @@ async def detach_source( if agent_state.enable_sleeptime: try: - source = await server.source_manager.get_source_by_id(source_id=source_id) + source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor) await server.block_manager.delete_block_async(block.id, actor) except: @@ -811,7 +811,7 @@ async def detach_folder_from_agent( if agent_state.enable_sleeptime: try: - source = await server.source_manager.get_source_by_id(source_id=folder_id) + source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor) block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor) await server.block_manager.delete_block_async(block.id, actor) except: diff --git a/letta/server/rest_api/routers/v1/folders.py b/letta/server/rest_api/routers/v1/folders.py index d925454c..908004ac 100644 --- a/letta/server/rest_api/routers/v1/folders.py +++ b/letta/server/rest_api/routers/v1/folders.py @@ -594,7 +594,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False): - source = await server.source_manager.get_source_by_id(source_id=source_id) + source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) for agent in agents: if agent.enable_sleeptime: diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index dc6b0f9e..e695d292 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -231,7 +231,7 @@ async def list_messages_for_batch( # Get messages directly using our efficient method messages = await server.batch_manager.get_messages_for_letta_batch_async( - letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, ascending=(order == "asc"), before=before, after=after + letta_batch_job_id=batch_id, actor=actor, limit=limit, agent_id=agent_id, sort_descending=(order == "desc"), cursor=after ) return LettaBatchMessages(messages=messages) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index aad28074..d5a38a9c 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -485,7 +485,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False): - source = await server.source_manager.get_source_by_id(source_id=source_id) + source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) for agent in agents: if agent.enable_sleeptime: diff --git a/letta/server/server.py b/letta/server/server.py index 7cb815a8..2197c38a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -986,7 +986,7 @@ class SyncServer(object): from letta.data_sources.connectors import DirectoryConnector # TODO: move this into a thread - source = await self.source_manager.get_source_by_id(source_id=source_id) + source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor) connector = DirectoryConnector(input_files=[file_path]) num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector) @@ -1225,7 +1225,7 @@ class SyncServer(object): embedding_models=embedding_models, organization_id=provider.organization_id, ) - await self.provider_manager.update_provider_last_synced_async(provider.id) + await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor) # Read from database provider_llm_models = await self.provider_manager.list_models_async( @@ -1307,7 +1307,7 @@ class SyncServer(object): embedding_models=emb_models, organization_id=provider.organization_id, ) - await self.provider_manager.update_provider_last_synced_async(provider.id) + await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor) # Read from database provider_embedding_models = await self.provider_manager.list_models_async( diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 46f39ca1..1947a8ee 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -30,10 +30,10 @@ from letta.schemas.agent_file import ( ) from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig from letta.schemas.enums import FileProcessingStatus, VectorDBProvider from letta.schemas.file import FileMetadata from letta.schemas.group import Group, GroupCreate +from letta.schemas.llm_config import LLMConfig from letta.schemas.mcp import MCPServer from letta.schemas.message import Message from letta.schemas.source import Source diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index cdd219b0..848c4868 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -508,7 +508,7 @@ class BlockManager: @enforce_types @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) @trace_method - async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: + async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]: """Retrieve a block by its ID, including tags.""" async with db_registry.async_session() as session: try: @@ -523,7 +523,7 @@ class BlockManager: @enforce_types @trace_method - async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: + async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]: """Retrieve blocks by their ids without loading unnecessary relationships. Async implementation.""" if not block_ids: return [] @@ -540,9 +540,8 @@ class BlockManager: noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags) ) - # Apply access control if actor is provided - if actor: - query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) + # Apply access control - actor is required for org-scoping + query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) # TODO: Add soft delete filter if applicable # if hasattr(BlockModel, "is_deleted"): diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index bd52e309..ee3db939 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -91,18 +91,17 @@ class FileManager: await session.rollback() return await self.get_file_by_id(file_metadata.id, actor=actor) - # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE) @trace_method # @async_redis_cache( - # key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}", + # key_func=lambda self, file_id, actor, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id}:{include_content}:{strip_directory_prefix}", # prefix="file_content", # ttl_s=3600, # model_class=PydanticFileMetadata, # ) async def get_file_by_id( - self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False + self, file_id: str, actor: PydanticUser, *, include_content: bool = False, strip_directory_prefix: bool = False ) -> Optional[PydanticFileMetadata]: """Retrieve a file by its ID. @@ -479,7 +478,7 @@ class FileManager: async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: """Delete a file by its ID.""" async with db_registry.async_session() as session: - file = await FileMetadataModel.read_async(db_session=session, identifier=file_id) + file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor) # invalidate cache for this file before deletion await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id) diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 6c09c2be..d544adf3 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -63,7 +63,7 @@ class LLMBatchManager: self, llm_batch_id: str, status: JobStatus, - actor: Optional[PydanticUser] = None, + actor: PydanticUser, latest_polling_response: Optional[BetaMessageBatch] = None, ) -> PydanticLLMBatchJob: """Update a batch job’s status and optionally its polling response.""" @@ -107,8 +107,8 @@ class LLMBatchManager: async def list_llm_batch_jobs_async( self, letta_batch_id: str, + actor: PydanticUser, limit: Optional[int] = None, - actor: Optional[PydanticUser] = None, after: Optional[str] = None, ) -> List[PydanticLLMBatchJob]: """ @@ -153,8 +153,8 @@ class LLMBatchManager: async def get_messages_for_letta_batch_async( self, letta_batch_job_id: str, + actor: PydanticUser, limit: int = 100, - actor: Optional[PydanticUser] = None, agent_id: Optional[str] = None, sort_descending: bool = True, cursor: Optional[str] = None, # Message ID as cursor diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 4628c5eb..47fc124b 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -237,12 +237,16 @@ class ProviderManager: @enforce_types @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) - async def update_provider_last_synced_async(self, provider_id: str) -> None: - """Update the last_synced timestamp for a provider.""" + async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None: + """Update the last_synced timestamp for a provider. + + Note: actor is optional to support system-level operations (e.g., during server initialization + for global providers). When actor is provided, org-scoping is enforced. + """ from datetime import datetime, timezone async with db_registry.async_session() as session: - provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=None) + provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor) provider.last_synced = datetime.now(timezone.utc) await session.commit() @@ -533,7 +537,7 @@ class ProviderManager: embedding_models=embedding_models, organization_id=actor.organization_id, ) - await self.update_provider_last_synced_async(provider.id) + await self.update_provider_last_synced_async(provider.id, actor=actor) except Exception as e: logger.error(f"Failed to sync models for provider '{provider.name}': {e}") diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index c34611e3..30849870 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -167,9 +167,7 @@ class SandboxConfigManager: @enforce_types @trace_method - async def get_sandbox_config_by_type_async( - self, type: SandboxType, actor: Optional[PydanticUser] = None - ) -> Optional[PydanticSandboxConfig]: + async def get_sandbox_config_by_type_async(self, type: SandboxType, actor: PydanticUser) -> Optional[PydanticSandboxConfig]: """Retrieve a sandbox config by its type.""" async with db_registry.async_session() as session: try: @@ -345,7 +343,7 @@ class SandboxConfigManager: @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) @trace_method async def get_sandbox_env_var_by_key_and_sandbox_config_id_async( - self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None + self, key: str, sandbox_config_id: str, actor: PydanticUser ) -> Optional[PydanticEnvVar]: """Retrieve a sandbox environment variable by its key and sandbox_config_id.""" async with db_registry.async_session() as session: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index b45c9128..6f1891e7 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -448,11 +448,10 @@ class SourceManager: return list(agent_ids) - # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) @trace_method - async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: + async def get_source_by_id(self, source_id: str, actor: PydanticUser) -> Optional[PydanticSource]: """Retrieve a source by its ID.""" async with db_registry.async_session() as session: source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py index 10579fc7..e628497d 100644 --- a/tests/test_server_providers.py +++ b/tests/test_server_providers.py @@ -2632,7 +2632,7 @@ async def test_byok_provider_last_synced_skips_sync_when_set(default_user, provi ) # Set last_synced to indicate models are already synced - await provider_manager.update_provider_last_synced_async(byok_provider.id) + await provider_manager.update_provider_last_synced_async(byok_provider.id, actor=default_user) # Create server server = SyncServer(init_with_default_org_and_user=False) @@ -2692,7 +2692,7 @@ async def test_base_provider_updates_last_synced_on_sync(default_user, provider_ embedding_models=[], organization_id=None, ) - await provider_manager.update_provider_last_synced_async(base_provider.id) + await provider_manager.update_provider_last_synced_async(base_provider.id, actor=default_user) # Verify last_synced was updated updated_providers = await provider_manager.list_providers_async(name=base_provider.name, actor=default_user)