feat: move more decrypt callsites outside db session (#8342)

This commit is contained in:
cthomas
2026-01-06 12:27:26 -08:00
committed by Caren Thomas
parent ccfd3d1432
commit 51a68d1250
6 changed files with 135 additions and 61 deletions

View File

@@ -111,7 +111,13 @@ from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import DatabaseChoice, model_settings, settings
from letta.utils import bounded_gather, calculate_file_defaults_based_on_context_window, enforce_types, united_diff
from letta.utils import (
bounded_gather,
calculate_file_defaults_based_on_context_window,
decrypt_agent_secrets,
enforce_types,
united_diff,
)
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -884,7 +890,11 @@ class AgentManager:
await session.flush()
await session.refresh(agent)
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
@enforce_types
@trace_method
@@ -911,34 +921,6 @@ class AgentManager:
# context manager now handles commits
# await session.commit()
async def _decrypt_agent_secrets(self, agents: List[PydanticAgentState]) -> List[PydanticAgentState]:
    """
    Fill in plaintext values for the agents' encrypted environment variables.

    Meant to run after the DB session has been released, so a pooled
    connection is not held while the expensive PBKDF2 work happens; this
    prevents connection pool exhaustion during high load.

    Decryption is fanned out with bounded concurrency to limit thread pool
    pressure while still allowing some parallelism in the dedicated crypto
    executor.

    Returns the same list that was passed in; env vars are updated in place.
    """

    async def _fill_plaintext(secret):
        # Nothing to decrypt from.
        if not secret.value_enc:
            return
        # A non-empty plaintext value is already present; leave it untouched.
        if secret.value is not None and secret.value != "":
            return
        secret.value = await secret.value_enc.get_plaintext_async()

    # One coroutine per env var across every agent that has any.
    pending = [
        _fill_plaintext(secret)
        for agent in agents
        if agent.tool_exec_environment_variables
        for secret in agent.tool_exec_environment_variables
    ]

    # Bounded concurrency (matches crypto executor size)
    if pending:
        await bounded_gather(pending, max_concurrency=8)

    return agents
@trace_method
async def list_agents_async(
self,
@@ -1015,7 +997,7 @@ class AgentManager:
)
# DB session released - now decrypt secrets outside session to prevent connection holding
return await self._decrypt_agent_secrets(agents_encrypted)
return await decrypt_agent_secrets(agents_encrypted)
@trace_method
async def count_agents_async(
@@ -1111,7 +1093,12 @@ class AgentManager:
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
result = await session.execute(query)
return await bounded_gather([agent.to_pydantic_async() for agent in result.scalars()])
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather([agent.to_pydantic_async(decrypt=False) for agent in result.scalars()])
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@trace_method
async def size_async(
@@ -1136,8 +1123,8 @@ class AgentManager:
) -> PydanticAgentState:
"""Fetch an agent by its ID."""
async with db_registry.async_session() as session:
try:
try:
async with db_registry.async_session() as session:
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
query = query.where(AgentModel.id == agent_id)
@@ -1149,13 +1136,17 @@ class AgentManager:
if agent is None:
raise NoResultFound(f"Agent with ID {agent_id} not found")
return await agent.to_pydantic_async(include_relationships=include_relationships, include=include)
except NoResultFound:
# Re-raise NoResultFound without logging to preserve 404 handling
raise
except Exception as e:
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
raise
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
except NoResultFound:
# Re-raise NoResultFound without logging to preserve 404 handling
raise
except Exception as e:
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
raise
@enforce_types
@trace_method
@@ -1166,8 +1157,8 @@ class AgentManager:
include_relationships: Optional[List[str]] = None,
) -> list[PydanticAgentState]:
"""Fetch a list of agents by their IDs."""
async with db_registry.async_session() as session:
try:
try:
async with db_registry.async_session() as session:
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
query = query.where(AgentModel.id.in_(agent_ids))
@@ -1180,10 +1171,16 @@ class AgentManager:
logger.warning(f"No agents found with IDs: {agent_ids}")
return []
return await bounded_gather([agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
except Exception as e:
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
raise
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=include_relationships, decrypt=False) for agent in agents]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
except Exception as e:
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
raise
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -1739,7 +1736,12 @@ class AgentManager:
agent = await agent.update_async(session, actor=actor)
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
# or even better, not refresh at all.
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
@enforce_types
@trace_method
@@ -1901,7 +1903,12 @@ class AgentManager:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
# or even better, not refresh at all.
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
# ======================================================================================================================
# Block management
@@ -1989,7 +1996,11 @@ class AgentManager:
# TODO: I have too many things rn so lets look at this later
# await session.commit()
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
@enforce_types
@trace_method
@@ -2010,7 +2021,12 @@ class AgentManager:
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
# Convert without decrypting to release DB connection before PBKDF2
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
# Decrypt secrets outside session
return (await decrypt_agent_secrets([agent_encrypted]))[0]
# ======================================================================================================================
# Passage Management

View File

@@ -18,7 +18,7 @@ from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.settings import DatabaseChoice, settings
from letta.utils import bounded_gather, enforce_types
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -554,8 +554,13 @@ class ArchiveManager:
result = await session.execute(query)
agents_orm = result.scalars().all()
agents = await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm])
return agents
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@trace_method

View File

@@ -19,7 +19,7 @@ from letta.schemas.enums import ActorType, PrimitiveType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import bounded_gather, enforce_types
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -505,8 +505,13 @@ class BlockManager:
result = await session.execute(query)
agents_orm = result.scalars().all()
agents = await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm])
return agents
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@trace_method

View File

@@ -24,7 +24,7 @@ from letta.schemas.identity import (
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import bounded_gather, enforce_types
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
from letta.validators import raise_on_invalid_id
@@ -336,7 +336,14 @@ class IdentityManager:
ascending=ascending,
identity_id=identity.id,
)
return await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents])
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)

View File

@@ -15,7 +15,7 @@ from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.source import Source as PydanticSource, SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import bounded_gather, enforce_types, printd
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types, printd
from letta.validators import raise_on_invalid_id
@@ -326,7 +326,13 @@ class SourceManager:
result = await session.execute(query)
agents_orm = result.scalars().all()
return await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=[]) for agent in agents_orm])
# Convert without decrypting to release DB connection before PBKDF2
agents_encrypted = await bounded_gather(
[agent.to_pydantic_async(include_relationships=[], include=[], decrypt=False) for agent in agents_orm]
)
# Decrypt secrets outside session
return await decrypt_agent_secrets(agents_encrypted)
@enforce_types
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)

View File

@@ -1471,3 +1471,38 @@ async def bounded_gather(coros: list[Coroutine], max_concurrency: int = 10) -> l
# Sort by original index to preserve order
indexed_results.sort(key=lambda x: x[0])
return [result for _, result in indexed_results]
async def decrypt_agent_secrets(agents: list) -> list:
    """
    Decrypt tool-execution environment variable secrets on agents, in place.

    Intended to be called after the DB session has been released, so that DB
    connections are not held during expensive PBKDF2 operations, preventing
    connection pool exhaustion during high load.

    Only env vars that carry an encrypted value (``value_enc``) and have no
    plaintext yet (``value`` is ``None`` or ``""``) are scheduled. When no
    env var qualifies, nothing is scheduled and the list is returned as-is.
    Qualifying vars are decrypted with bounded concurrency to limit thread
    pool pressure while allowing some parallelism in the dedicated crypto
    executor.

    Args:
        agents: List of PydanticAgentState objects with encrypted secrets

    Returns:
        Same list with secrets decrypted
    """

    def needs_decryption(env_var) -> bool:
        # Encrypted payload present, plaintext missing or empty.
        return bool(env_var.value_enc) and (env_var.value is None or env_var.value == "")

    async def decrypt_env_var(env_var) -> None:
        env_var.value = await env_var.value_enc.get_plaintext_async()

    # Collect only the env vars that actually need decryption, so no-op
    # coroutines are never created for already-populated values.
    pending = [
        env_var
        for agent in agents
        for env_var in (agent.tool_exec_environment_variables or ())
        if needs_decryption(env_var)
    ]

    # Decrypt with bounded concurrency (matches crypto executor size)
    if pending:
        await bounded_gather([decrypt_env_var(env_var) for env_var in pending], max_concurrency=8)

    return agents