fix(sec): first pass of ensuring actor id is required everywhere (#9126)

first pass of ensuring actor id is required
This commit is contained in:
Kian Jones
2026-01-27 11:15:10 -08:00
committed by Caren Thomas
parent b34ad43691
commit 0099a95a43
14 changed files with 58 additions and 37 deletions

View File

@@ -42,7 +42,9 @@ def handle_db_timeout(func):
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
except QueryCanceledError as e:
logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
logger.error(
f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
)
raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
return wrapper
@@ -56,7 +58,9 @@ def handle_db_timeout(func):
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
except QueryCanceledError as e:
logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
logger.error(
f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
)
raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
return async_wrapper
@@ -207,6 +211,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
"""
Constructs the query for listing records.
"""
# Security check: if the model has organization_id column, actor should be provided
if actor is None and hasattr(cls, "organization_id"):
logger.warning(f"SECURITY: Listing org-scoped model {cls.__name__} without actor. This bypasses organization filtering.")
query = select(cls)
if join_model and join_conditions:
@@ -446,6 +454,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
):
logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
# Security check: if the model has organization_id column, actor should be provided
# to ensure proper org-scoping. Log a warning if actor is None.
if actor is None and hasattr(cls, "organization_id"):
logger.warning(
f"SECURITY: Reading org-scoped model {cls.__name__} without actor. "
f"IDs: {identifiers}. This bypasses organization filtering."
)
# Start the query
query = select(cls)
# Collect query conditions for better error reporting
@@ -681,6 +697,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
**kwargs,
):
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
# Security check: if the model has organization_id column, actor should be provided
if actor is None and hasattr(cls, "organization_id"):
logger.warning(
f"SECURITY: Calculating size for org-scoped model {cls.__name__} without actor. This bypasses organization filtering."
)
query = select(func.count(1)).select_from(cls)
if actor:

View File

@@ -721,7 +721,7 @@ async def attach_source(
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
if agent_state.enable_sleeptime:
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
return agent_state
@@ -748,7 +748,7 @@ async def attach_folder_to_agent(
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
if agent_state.enable_sleeptime:
source = await server.source_manager.get_source_by_id(source_id=folder_id)
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
if is_1_0_sdk_version(headers):
@@ -779,7 +779,7 @@ async def detach_source(
if agent_state.enable_sleeptime:
try:
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
await server.block_manager.delete_block_async(block.id, actor)
except:
@@ -811,7 +811,7 @@ async def detach_folder_from_agent(
if agent_state.enable_sleeptime:
try:
source = await server.source_manager.get_source_by_id(source_id=folder_id)
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
await server.block_manager.delete_block_async(block.id, actor)
except:

View File

@@ -594,7 +594,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:

View File

@@ -231,7 +231,7 @@ async def list_messages_for_batch(
# Get messages directly using our efficient method
messages = await server.batch_manager.get_messages_for_letta_batch_async(
letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, ascending=(order == "asc"), before=before, after=after
letta_batch_job_id=batch_id, actor=actor, limit=limit, agent_id=agent_id, sort_descending=(order == "desc"), cursor=after
)
return LettaBatchMessages(messages=messages)

View File

@@ -485,7 +485,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = await server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:

View File

@@ -986,7 +986,7 @@ class SyncServer(object):
from letta.data_sources.connectors import DirectoryConnector
# TODO: move this into a thread
source = await self.source_manager.get_source_by_id(source_id=source_id)
source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
connector = DirectoryConnector(input_files=[file_path])
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
@@ -1225,7 +1225,7 @@ class SyncServer(object):
embedding_models=embedding_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id)
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
# Read from database
provider_llm_models = await self.provider_manager.list_models_async(
@@ -1307,7 +1307,7 @@ class SyncServer(object):
embedding_models=emb_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id)
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
# Read from database
provider_embedding_models = await self.provider_manager.list_models_async(

View File

@@ -30,10 +30,10 @@ from letta.schemas.agent_file import (
)
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
from letta.schemas.file import FileMetadata
from letta.schemas.group import Group, GroupCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message
from letta.schemas.source import Source

View File

@@ -508,7 +508,7 @@ class BlockManager:
@enforce_types
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@trace_method
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]:
"""Retrieve a block by its ID, including tags."""
async with db_registry.async_session() as session:
try:
@@ -523,7 +523,7 @@ class BlockManager:
@enforce_types
@trace_method
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
if not block_ids:
return []
@@ -540,9 +540,8 @@ class BlockManager:
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
)
# Apply access control if actor is provided
if actor:
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Apply access control - actor is required for org-scoping
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# TODO: Add soft delete filter if applicable
# if hasattr(BlockModel, "is_deleted"):

View File

@@ -91,18 +91,17 @@ class FileManager:
await session.rollback()
return await self.get_file_by_id(file_metadata.id, actor=actor)
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
@trace_method
# @async_redis_cache(
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
# key_func=lambda self, file_id, actor, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id}:{include_content}:{strip_directory_prefix}",
# prefix="file_content",
# ttl_s=3600,
# model_class=PydanticFileMetadata,
# )
async def get_file_by_id(
self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False
self, file_id: str, actor: PydanticUser, *, include_content: bool = False, strip_directory_prefix: bool = False
) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID.
@@ -479,7 +478,7 @@ class FileManager:
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
async with db_registry.async_session() as session:
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
# invalidate cache for this file before deletion
await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id)

View File

@@ -63,7 +63,7 @@ class LLMBatchManager:
self,
llm_batch_id: str,
status: JobStatus,
actor: Optional[PydanticUser] = None,
actor: PydanticUser,
latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
"""Update a batch job's status and optionally its polling response."""
@@ -107,8 +107,8 @@ class LLMBatchManager:
async def list_llm_batch_jobs_async(
self,
letta_batch_id: str,
actor: PydanticUser,
limit: Optional[int] = None,
actor: Optional[PydanticUser] = None,
after: Optional[str] = None,
) -> List[PydanticLLMBatchJob]:
"""
@@ -153,8 +153,8 @@ class LLMBatchManager:
async def get_messages_for_letta_batch_async(
self,
letta_batch_job_id: str,
actor: PydanticUser,
limit: int = 100,
actor: Optional[PydanticUser] = None,
agent_id: Optional[str] = None,
sort_descending: bool = True,
cursor: Optional[str] = None, # Message ID as cursor

View File

@@ -237,12 +237,16 @@ class ProviderManager:
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def update_provider_last_synced_async(self, provider_id: str) -> None:
"""Update the last_synced timestamp for a provider."""
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
"""Update the last_synced timestamp for a provider.
Note: actor is optional to support system-level operations (e.g., during server initialization
for global providers). When actor is provided, org-scoping is enforced.
"""
from datetime import datetime, timezone
async with db_registry.async_session() as session:
provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=None)
provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
provider.last_synced = datetime.now(timezone.utc)
await session.commit()
@@ -533,7 +537,7 @@ class ProviderManager:
embedding_models=embedding_models,
organization_id=actor.organization_id,
)
await self.update_provider_last_synced_async(provider.id)
await self.update_provider_last_synced_async(provider.id, actor=actor)
except Exception as e:
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")

View File

@@ -167,9 +167,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
async def get_sandbox_config_by_type_async(
self, type: SandboxType, actor: Optional[PydanticUser] = None
) -> Optional[PydanticSandboxConfig]:
async def get_sandbox_config_by_type_async(self, type: SandboxType, actor: PydanticUser) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox config by its type."""
async with db_registry.async_session() as session:
try:
@@ -345,7 +343,7 @@ class SandboxConfigManager:
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
@trace_method
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
self, key: str, sandbox_config_id: str, actor: PydanticUser
) -> Optional[PydanticEnvVar]:
"""Retrieve a sandbox environment variable by its key and sandbox_config_id."""
async with db_registry.async_session() as session:

View File

@@ -448,11 +448,10 @@ class SourceManager:
return list(agent_ids)
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
@trace_method
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
async def get_source_by_id(self, source_id: str, actor: PydanticUser) -> Optional[PydanticSource]:
"""Retrieve a source by its ID."""
async with db_registry.async_session() as session:
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)

View File

@@ -2632,7 +2632,7 @@ async def test_byok_provider_last_synced_skips_sync_when_set(default_user, provi
)
# Set last_synced to indicate models are already synced
await provider_manager.update_provider_last_synced_async(byok_provider.id)
await provider_manager.update_provider_last_synced_async(byok_provider.id, actor=default_user)
# Create server
server = SyncServer(init_with_default_org_and_user=False)
@@ -2692,7 +2692,7 @@ async def test_base_provider_updates_last_synced_on_sync(default_user, provider_
embedding_models=[],
organization_id=None,
)
await provider_manager.update_provider_last_synced_async(base_provider.id)
await provider_manager.update_provider_last_synced_async(base_provider.id, actor=default_user)
# Verify last_synced was updated
updated_providers = await provider_manager.list_providers_async(name=base_provider.name, actor=default_user)