From 51a68d12508dfd17ef3083176a2b9586be48aba2 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 6 Jan 2026 12:27:26 -0800 Subject: [PATCH] feat: move more decrypt callsites outside db session (#8342) --- letta/services/agent_manager.py | 118 ++++++++++++++++------------- letta/services/archive_manager.py | 11 ++- letta/services/block_manager.py | 11 ++- letta/services/identity_manager.py | 11 ++- letta/services/source_manager.py | 10 ++- letta/utils.py | 35 +++++++++ 6 files changed, 135 insertions(+), 61 deletions(-) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d32541fd..10cd123d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -111,7 +111,13 @@ from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import DatabaseChoice, model_settings, settings -from letta.utils import bounded_gather, calculate_file_defaults_based_on_context_window, enforce_types, united_diff +from letta.utils import ( + bounded_gather, + calculate_file_defaults_based_on_context_window, + decrypt_agent_secrets, + enforce_types, + united_diff, +) from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -884,7 +890,11 @@ class AgentManager: await session.flush() await session.refresh(agent) - return await agent.to_pydantic_async() + # Convert without decrypting to release DB connection before PBKDF2 + agent_encrypted = await agent.to_pydantic_async(decrypt=False) + + # Decrypt secrets outside session + return (await decrypt_agent_secrets([agent_encrypted]))[0] @enforce_types @trace_method @@ -911,34 +921,6 @@ class AgentManager: # context manager now handles commits # await session.commit() - async def _decrypt_agent_secrets(self, agents: List[PydanticAgentState]) -> List[PydanticAgentState]: - """ - Decrypt secrets for all agents outside DB session. 
- - This allows DB connections to be released before expensive PBKDF2 operations, - preventing connection pool exhaustion during high load. - - Uses bounded concurrency to limit thread pool pressure while allowing some - parallelism in the dedicated crypto executor. - """ - - async def decrypt_env_var(env_var): - if env_var.value_enc and (env_var.value is None or env_var.value == ""): - env_var.value = await env_var.value_enc.get_plaintext_async() - - # Collect all env vars that need decryption - decrypt_tasks = [] - for agent in agents: - if agent.tool_exec_environment_variables: - for env_var in agent.tool_exec_environment_variables: - decrypt_tasks.append(decrypt_env_var(env_var)) - - # Decrypt with bounded concurrency (matches crypto executor size) - if decrypt_tasks: - await bounded_gather(decrypt_tasks, max_concurrency=8) - - return agents - @trace_method async def list_agents_async( self, @@ -1015,7 +997,7 @@ class AgentManager: ) # DB session released - now decrypt secrets outside session to prevent connection holding - return await self._decrypt_agent_secrets(agents_encrypted) + return await decrypt_agent_secrets(agents_encrypted) @trace_method async def count_agents_async( @@ -1111,7 +1093,12 @@ class AgentManager: query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit) result = await session.execute(query) - return await bounded_gather([agent.to_pydantic_async() for agent in result.scalars()]) + + # Convert without decrypting to release DB connection before PBKDF2 + agents_encrypted = await bounded_gather([agent.to_pydantic_async(decrypt=False) for agent in result.scalars()]) + + # Decrypt secrets outside session + return await decrypt_agent_secrets(agents_encrypted) @trace_method async def size_async( @@ -1136,8 +1123,8 @@ class AgentManager: ) -> PydanticAgentState: """Fetch an agent by its ID.""" - async with db_registry.async_session() as session: - try: + try: + async with db_registry.async_session() as session: query = 
select(AgentModel) query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) query = query.where(AgentModel.id == agent_id) @@ -1149,13 +1136,17 @@ class AgentManager: if agent is None: raise NoResultFound(f"Agent with ID {agent_id} not found") - return await agent.to_pydantic_async(include_relationships=include_relationships, include=include) - except NoResultFound: - # Re-raise NoResultFound without logging to preserve 404 handling - raise - except Exception as e: - logger.error(f"Error fetching agent {agent_id}: {str(e)}") - raise + # Convert without decrypting to release DB connection before PBKDF2 + agent_encrypted = await agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False) + + # Decrypt secrets outside session + return (await decrypt_agent_secrets([agent_encrypted]))[0] + except NoResultFound: + # Re-raise NoResultFound without logging to preserve 404 handling + raise + except Exception as e: + logger.error(f"Error fetching agent {agent_id}: {str(e)}") + raise @enforce_types @trace_method @@ -1166,8 +1157,8 @@ class AgentManager: include_relationships: Optional[List[str]] = None, ) -> list[PydanticAgentState]: """Fetch a list of agents by their IDs.""" - async with db_registry.async_session() as session: - try: + try: + async with db_registry.async_session() as session: query = select(AgentModel) query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) query = query.where(AgentModel.id.in_(agent_ids)) @@ -1180,10 +1171,16 @@ class AgentManager: logger.warning(f"No agents found with IDs: {agent_ids}") return [] - return await bounded_gather([agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) - except Exception as e: - logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}") - raise + # Convert without decrypting to release DB connection before PBKDF2 + agents_encrypted = await bounded_gather( + 
[agent.to_pydantic_async(include_relationships=include_relationships, decrypt=False) for agent in agents] + ) + + # Decrypt secrets outside session + return await decrypt_agent_secrets(agents_encrypted) + except Exception as e: + logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}") + raise @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -1739,7 +1736,12 @@ class AgentManager: agent = await agent.update_async(session, actor=actor) # TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields. # or even better, not refresh at all. - return await agent.to_pydantic_async() + + # Convert without decrypting to release DB connection before PBKDF2 + agent_encrypted = await agent.to_pydantic_async(decrypt=False) + + # Decrypt secrets outside session + return (await decrypt_agent_secrets([agent_encrypted]))[0] @enforce_types @trace_method @@ -1901,7 +1903,12 @@ class AgentManager: agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) # TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields. # or even better, not refresh at all. 
- return await agent.to_pydantic_async() + + # Convert without decrypting to release DB connection before PBKDF2 + agent_encrypted = await agent.to_pydantic_async(decrypt=False) + + # Decrypt secrets outside session + return (await decrypt_agent_secrets([agent_encrypted]))[0] # ====================================================================================================================== # Block management @@ -1989,7 +1996,11 @@ class AgentManager: # TODO: I have too many things rn so lets look at this later # await session.commit() - return await agent.to_pydantic_async() + # Convert without decrypting to release DB connection before PBKDF2 + agent_encrypted = await agent.to_pydantic_async(decrypt=False) + + # Decrypt secrets outside session + return (await decrypt_agent_secrets([agent_encrypted]))[0] @enforce_types @trace_method @@ -2010,7 +2021,12 @@ class AgentManager: raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'") await agent.update_async(session, actor=actor) - return await agent.to_pydantic_async() + + # Convert without decrypting to release DB connection before PBKDF2 + agent_encrypted = await agent.to_pydantic_async(decrypt=False) + + # Decrypt secrets outside session + return (await decrypt_agent_secrets([agent_encrypted]))[0] # ====================================================================================================================== # Passage Management diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 52b5843c..b982ca4d 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -18,7 +18,7 @@ from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.helpers.agent_manager_helper import validate_agent_exists_async from letta.settings import DatabaseChoice, settings -from letta.utils import bounded_gather, enforce_types +from letta.utils import 
bounded_gather, decrypt_agent_secrets, enforce_types from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -554,8 +554,13 @@ class ArchiveManager: result = await session.execute(query) agents_orm = result.scalars().all() - agents = await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm]) - return agents + # Convert without decrypting to release DB connection before PBKDF2 + agents_encrypted = await bounded_gather( + [agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm] + ) + + # Decrypt secrets outside session + return await decrypt_agent_secrets(agents_encrypted) @enforce_types @trace_method diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index ebfc03eb..5229ba1e 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -19,7 +19,7 @@ from letta.schemas.enums import ActorType, PrimitiveType from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import DatabaseChoice, settings -from letta.utils import bounded_gather, enforce_types +from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -505,8 +505,13 @@ class BlockManager: result = await session.execute(query) agents_orm = result.scalars().all() - agents = await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm]) - return agents + # Convert without decrypting to release DB connection before PBKDF2 + agents_encrypted = await bounded_gather( + [agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm] + ) + + # Decrypt secrets outside session + return await decrypt_agent_secrets(agents_encrypted) @enforce_types @trace_method diff --git 
a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 6fa9e5bc..fc416094 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -24,7 +24,7 @@ from letta.schemas.identity import ( from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import DatabaseChoice, settings -from letta.utils import bounded_gather, enforce_types +from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types from letta.validators import raise_on_invalid_id @@ -336,7 +336,14 @@ class IdentityManager: ascending=ascending, identity_id=identity.id, ) - return await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents]) + + # Convert without decrypting to release DB connection before PBKDF2 + agents_encrypted = await bounded_gather( + [agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents] + ) + + # Decrypt secrets outside session + return await decrypt_agent_secrets(agents_encrypted) @enforce_types @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 123bfc3e..b45c9128 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -15,7 +15,7 @@ from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.source import Source as PydanticSource, SourceUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry -from letta.utils import bounded_gather, enforce_types, printd +from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types, printd from letta.validators import raise_on_invalid_id @@ -326,7 +326,13 @@ class SourceManager: result = await session.execute(query) agents_orm = result.scalars().all() - return await 
bounded_gather([agent.to_pydantic_async(include_relationships=[], include=[]) for agent in agents_orm])
+            # Convert without decrypting to release DB connection before PBKDF2
+            agents_encrypted = await bounded_gather(
+                [agent.to_pydantic_async(include_relationships=[], include=[], decrypt=False) for agent in agents_orm]
+            )
+
+        # Decrypt secrets outside session
+        return await decrypt_agent_secrets(agents_encrypted)
 
     @enforce_types
     @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
diff --git a/letta/utils.py b/letta/utils.py
index f360d8f3..d34546d5 100644
--- a/letta/utils.py
+++ b/letta/utils.py
@@ -1471,3 +1471,38 @@ async def bounded_gather(coros: list[Coroutine], max_concurrency: int = 10) -> l
     # Sort by original index to preserve order
     indexed_results.sort(key=lambda x: x[0])
     return [result for _, result in indexed_results]
+
+
+async def decrypt_agent_secrets(agents: list) -> list:
+    """
+    Populate plaintext values for agents' encrypted tool-exec env vars.
+
+    Meant to be awaited once the DB session has been closed, so the connection
+    is back in the pool while the expensive PBKDF2-based decryption runs.
+
+    An env var is decrypted only when it has a `value_enc` and its `value` is
+    None or "". The work is handed to bounded_gather with max_concurrency=8.
+
+    Args:
+        agents: List of PydanticAgentState objects, modified in place
+
+    Returns:
+        The same list that was passed in
+    """
+
+    async def fill_plaintext(var):
+        if not var.value_enc:
+            return
+        if var.value is None or var.value == "":
+            var.value = await var.value_enc.get_plaintext_async()
+
+    # One coroutine per env var, across every agent that has any
+    pending = [
+        fill_plaintext(var)
+        for state in agents
+        if state.tool_exec_environment_variables
+        for var in state.tool_exec_environment_variables
+    ]
+    if pending:  # nothing collected -> skip the gather call
+        await bounded_gather(pending, max_concurrency=8)
+    return agents