fix(sec): first pass of ensuring actor id is required everywhere (#9126)
First pass of ensuring that an actor ID is required.
This commit is contained in:
@@ -42,7 +42,9 @@ def handle_db_timeout(func):
|
||||
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
||||
except QueryCanceledError as e:
|
||||
logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
logger.error(
|
||||
f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
|
||||
)
|
||||
raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
|
||||
|
||||
return wrapper
|
||||
@@ -56,7 +58,9 @@ def handle_db_timeout(func):
|
||||
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
||||
except QueryCanceledError as e:
|
||||
logger.error(f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
||||
logger.error(
|
||||
f"Query canceled (statement timeout) while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}"
|
||||
)
|
||||
raise DatabaseTimeoutError(message=f"Query canceled due to statement timeout in {func.__name__}.", original_exception=e)
|
||||
|
||||
return async_wrapper
|
||||
@@ -207,6 +211,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
"""
|
||||
Constructs the query for listing records.
|
||||
"""
|
||||
# Security check: if the model has organization_id column, actor should be provided
|
||||
if actor is None and hasattr(cls, "organization_id"):
|
||||
logger.warning(f"SECURITY: Listing org-scoped model {cls.__name__} without actor. This bypasses organization filtering.")
|
||||
|
||||
query = select(cls)
|
||||
|
||||
if join_model and join_conditions:
|
||||
@@ -446,6 +454,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
):
|
||||
logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
|
||||
|
||||
# Security check: if the model has organization_id column, actor should be provided
|
||||
# to ensure proper org-scoping. Log a warning if actor is None.
|
||||
if actor is None and hasattr(cls, "organization_id"):
|
||||
logger.warning(
|
||||
f"SECURITY: Reading org-scoped model {cls.__name__} without actor. "
|
||||
f"IDs: {identifiers}. This bypasses organization filtering."
|
||||
)
|
||||
|
||||
# Start the query
|
||||
query = select(cls)
|
||||
# Collect query conditions for better error reporting
|
||||
@@ -681,6 +697,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
**kwargs,
|
||||
):
|
||||
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
||||
|
||||
# Security check: if the model has organization_id column, actor should be provided
|
||||
if actor is None and hasattr(cls, "organization_id"):
|
||||
logger.warning(
|
||||
f"SECURITY: Calculating size for org-scoped model {cls.__name__} without actor. This bypasses organization filtering."
|
||||
)
|
||||
query = select(func.count(1)).select_from(cls)
|
||||
|
||||
if actor:
|
||||
|
||||
@@ -721,7 +721,7 @@ async def attach_source(
|
||||
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
|
||||
|
||||
return agent_state
|
||||
@@ -748,7 +748,7 @@ async def attach_folder_to_agent(
|
||||
await server.agent_manager.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
|
||||
safe_create_task(server.sleeptime_document_ingest_async(agent_state, source, actor), label="sleeptime_document_ingest_async")
|
||||
|
||||
if is_1_0_sdk_version(headers):
|
||||
@@ -779,7 +779,7 @@ async def detach_source(
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
try:
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
await server.block_manager.delete_block_async(block.id, actor)
|
||||
except:
|
||||
@@ -811,7 +811,7 @@ async def detach_folder_from_agent(
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
try:
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
await server.block_manager.delete_block_async(block.id, actor)
|
||||
except:
|
||||
|
||||
@@ -594,7 +594,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:
|
||||
|
||||
|
||||
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
for agent in agents:
|
||||
if agent.enable_sleeptime:
|
||||
|
||||
@@ -231,7 +231,7 @@ async def list_messages_for_batch(
|
||||
|
||||
# Get messages directly using our efficient method
|
||||
messages = await server.batch_manager.get_messages_for_letta_batch_async(
|
||||
letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, ascending=(order == "asc"), before=before, after=after
|
||||
letta_batch_job_id=batch_id, actor=actor, limit=limit, agent_id=agent_id, sort_descending=(order == "desc"), cursor=after
|
||||
)
|
||||
|
||||
return LettaBatchMessages(messages=messages)
|
||||
|
||||
@@ -485,7 +485,7 @@ async def load_file_to_source_async(server: SyncServer, source_id: str, job_id:
|
||||
|
||||
|
||||
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
for agent in agents:
|
||||
if agent.enable_sleeptime:
|
||||
|
||||
@@ -986,7 +986,7 @@ class SyncServer(object):
|
||||
from letta.data_sources.connectors import DirectoryConnector
|
||||
|
||||
# TODO: move this into a thread
|
||||
source = await self.source_manager.get_source_by_id(source_id=source_id)
|
||||
source = await self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
connector = DirectoryConnector(input_files=[file_path])
|
||||
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
|
||||
|
||||
@@ -1225,7 +1225,7 @@ class SyncServer(object):
|
||||
embedding_models=embedding_models,
|
||||
organization_id=provider.organization_id,
|
||||
)
|
||||
await self.provider_manager.update_provider_last_synced_async(provider.id)
|
||||
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
|
||||
|
||||
# Read from database
|
||||
provider_llm_models = await self.provider_manager.list_models_async(
|
||||
@@ -1307,7 +1307,7 @@ class SyncServer(object):
|
||||
embedding_models=emb_models,
|
||||
organization_id=provider.organization_id,
|
||||
)
|
||||
await self.provider_manager.update_provider_last_synced_async(provider.id)
|
||||
await self.provider_manager.update_provider_last_synced_async(provider.id, actor=actor)
|
||||
|
||||
# Read from database
|
||||
provider_embedding_models = await self.provider_manager.list_models_async(
|
||||
|
||||
@@ -30,10 +30,10 @@ from letta.schemas.agent_file import (
|
||||
)
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.enums import FileProcessingStatus, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.source import Source
|
||||
|
||||
@@ -508,7 +508,7 @@ class BlockManager:
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@trace_method
|
||||
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
async def get_block_by_id_async(self, block_id: str, actor: PydanticUser) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its ID, including tags."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
@@ -523,7 +523,7 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: PydanticUser) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
|
||||
if not block_ids:
|
||||
return []
|
||||
@@ -540,9 +540,8 @@ class BlockManager:
|
||||
noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups), noload(BlockModel.tags)
|
||||
)
|
||||
|
||||
# Apply access control if actor is provided
|
||||
if actor:
|
||||
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
# Apply access control - actor is required for org-scoping
|
||||
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
|
||||
# TODO: Add soft delete filter if applicable
|
||||
# if hasattr(BlockModel, "is_deleted"):
|
||||
|
||||
@@ -91,18 +91,17 @@ class FileManager:
|
||||
await session.rollback()
|
||||
return await self.get_file_by_id(file_metadata.id, actor=actor)
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
@trace_method
|
||||
# @async_redis_cache(
|
||||
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
|
||||
# key_func=lambda self, file_id, actor, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id}:{include_content}:{strip_directory_prefix}",
|
||||
# prefix="file_content",
|
||||
# ttl_s=3600,
|
||||
# model_class=PydanticFileMetadata,
|
||||
# )
|
||||
async def get_file_by_id(
|
||||
self, file_id: str, actor: Optional[PydanticUser] = None, *, include_content: bool = False, strip_directory_prefix: bool = False
|
||||
self, file_id: str, actor: PydanticUser, *, include_content: bool = False, strip_directory_prefix: bool = False
|
||||
) -> Optional[PydanticFileMetadata]:
|
||||
"""Retrieve a file by its ID.
|
||||
|
||||
@@ -479,7 +478,7 @@ class FileManager:
|
||||
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
|
||||
"""Delete a file by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
|
||||
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
|
||||
|
||||
# invalidate cache for this file before deletion
|
||||
await self._invalidate_file_caches(file_id, actor, file.original_file_name, file.source_id)
|
||||
|
||||
@@ -63,7 +63,7 @@ class LLMBatchManager:
|
||||
self,
|
||||
llm_batch_id: str,
|
||||
status: JobStatus,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
actor: PydanticUser,
|
||||
latest_polling_response: Optional[BetaMessageBatch] = None,
|
||||
) -> PydanticLLMBatchJob:
|
||||
"""Update a batch job’s status and optionally its polling response."""
|
||||
@@ -107,8 +107,8 @@ class LLMBatchManager:
|
||||
async def list_llm_batch_jobs_async(
|
||||
self,
|
||||
letta_batch_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: Optional[int] = None,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> List[PydanticLLMBatchJob]:
|
||||
"""
|
||||
@@ -153,8 +153,8 @@ class LLMBatchManager:
|
||||
async def get_messages_for_letta_batch_async(
|
||||
self,
|
||||
letta_batch_job_id: str,
|
||||
actor: PydanticUser,
|
||||
limit: int = 100,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
sort_descending: bool = True,
|
||||
cursor: Optional[str] = None, # Message ID as cursor
|
||||
|
||||
@@ -237,12 +237,16 @@ class ProviderManager:
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
async def update_provider_last_synced_async(self, provider_id: str) -> None:
|
||||
"""Update the last_synced timestamp for a provider."""
|
||||
async def update_provider_last_synced_async(self, provider_id: str, actor: Optional[PydanticUser] = None) -> None:
|
||||
"""Update the last_synced timestamp for a provider.
|
||||
|
||||
Note: actor is optional to support system-level operations (e.g., during server initialization
|
||||
for global providers). When actor is provided, org-scoping is enforced.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=None)
|
||||
provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
||||
provider.last_synced = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
@@ -533,7 +537,7 @@ class ProviderManager:
|
||||
embedding_models=embedding_models,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
await self.update_provider_last_synced_async(provider.id)
|
||||
await self.update_provider_last_synced_async(provider.id, actor=actor)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")
|
||||
|
||||
@@ -167,9 +167,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_sandbox_config_by_type_async(
|
||||
self, type: SandboxType, actor: Optional[PydanticUser] = None
|
||||
) -> Optional[PydanticSandboxConfig]:
|
||||
async def get_sandbox_config_by_type_async(self, type: SandboxType, actor: PydanticUser) -> Optional[PydanticSandboxConfig]:
|
||||
"""Retrieve a sandbox config by its type."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
@@ -345,7 +343,7 @@ class SandboxConfigManager:
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
@trace_method
|
||||
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
|
||||
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
|
||||
self, key: str, sandbox_config_id: str, actor: PydanticUser
|
||||
) -> Optional[PydanticEnvVar]:
|
||||
"""Retrieve a sandbox environment variable by its key and sandbox_config_id."""
|
||||
async with db_registry.async_session() as session:
|
||||
|
||||
@@ -448,11 +448,10 @@ class SourceManager:
|
||||
|
||||
return list(agent_ids)
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
@trace_method
|
||||
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
|
||||
async def get_source_by_id(self, source_id: str, actor: PydanticUser) -> Optional[PydanticSource]:
|
||||
"""Retrieve a source by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
|
||||
|
||||
@@ -2632,7 +2632,7 @@ async def test_byok_provider_last_synced_skips_sync_when_set(default_user, provi
|
||||
)
|
||||
|
||||
# Set last_synced to indicate models are already synced
|
||||
await provider_manager.update_provider_last_synced_async(byok_provider.id)
|
||||
await provider_manager.update_provider_last_synced_async(byok_provider.id, actor=default_user)
|
||||
|
||||
# Create server
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
@@ -2692,7 +2692,7 @@ async def test_base_provider_updates_last_synced_on_sync(default_user, provider_
|
||||
embedding_models=[],
|
||||
organization_id=None,
|
||||
)
|
||||
await provider_manager.update_provider_last_synced_async(base_provider.id)
|
||||
await provider_manager.update_provider_last_synced_async(base_provider.id, actor=default_user)
|
||||
|
||||
# Verify last_synced was updated
|
||||
updated_providers = await provider_manager.list_providers_async(name=base_provider.name, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user