feat: move more decrypt callsites outside db session (#8342)
This commit is contained in:
@@ -111,7 +111,13 @@ from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import DatabaseChoice, model_settings, settings
|
||||
from letta.utils import bounded_gather, calculate_file_defaults_based_on_context_window, enforce_types, united_diff
|
||||
from letta.utils import (
|
||||
bounded_gather,
|
||||
calculate_file_defaults_based_on_context_window,
|
||||
decrypt_agent_secrets,
|
||||
enforce_types,
|
||||
united_diff,
|
||||
)
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -884,7 +890,11 @@ class AgentManager:
|
||||
await session.flush()
|
||||
await session.refresh(agent)
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -911,34 +921,6 @@ class AgentManager:
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
async def _decrypt_agent_secrets(self, agents: List[PydanticAgentState]) -> List[PydanticAgentState]:
    """
    Decrypt secrets for all agents outside DB session.

    Thin wrapper kept for existing callers: the decryption logic lives in the
    module-level ``decrypt_agent_secrets`` helper imported from ``letta.utils``,
    so this method and the other managers share a single implementation.

    Call this only after the DB session has been exited, so the connection is
    not held while the (expensive) decryption runs.

    Args:
        agents: Agent states whose ``tool_exec_environment_variables`` may
            still hold encrypted values.

    Returns:
        The same list object, with env var values decrypted in place.
    """
    # Delegate instead of duplicating the per-env-var decrypt loop here.
    return await decrypt_agent_secrets(agents)
|
||||
|
||||
@trace_method
|
||||
async def list_agents_async(
|
||||
self,
|
||||
@@ -1015,7 +997,7 @@ class AgentManager:
|
||||
)
|
||||
|
||||
# DB session released - now decrypt secrets outside session to prevent connection holding
|
||||
return await self._decrypt_agent_secrets(agents_encrypted)
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@trace_method
|
||||
async def count_agents_async(
|
||||
@@ -1111,7 +1093,12 @@ class AgentManager:
|
||||
|
||||
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
|
||||
result = await session.execute(query)
|
||||
return await bounded_gather([agent.to_pydantic_async() for agent in result.scalars()])
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather([agent.to_pydantic_async(decrypt=False) for agent in result.scalars()])
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@trace_method
|
||||
async def size_async(
|
||||
@@ -1136,8 +1123,8 @@ class AgentManager:
|
||||
) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(AgentModel.id == agent_id)
|
||||
@@ -1149,13 +1136,17 @@ class AgentManager:
|
||||
if agent is None:
|
||||
raise NoResultFound(f"Agent with ID {agent_id} not found")
|
||||
|
||||
return await agent.to_pydantic_async(include_relationships=include_relationships, include=include)
|
||||
except NoResultFound:
|
||||
# Re-raise NoResultFound without logging to preserve 404 handling
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
|
||||
raise
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(include_relationships=include_relationships, include=include, decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
except NoResultFound:
|
||||
# Re-raise NoResultFound without logging to preserve 404 handling
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1166,8 +1157,8 @@ class AgentManager:
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
) -> list[PydanticAgentState]:
|
||||
"""Fetch a list of agents by their IDs."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(AgentModel.id.in_(agent_ids))
|
||||
@@ -1180,10 +1171,16 @@ class AgentManager:
|
||||
logger.warning(f"No agents found with IDs: {agent_ids}")
|
||||
return []
|
||||
|
||||
return await bounded_gather([agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
|
||||
raise
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=include_relationships, decrypt=False) for agent in agents]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -1739,7 +1736,12 @@ class AgentManager:
|
||||
agent = await agent.update_async(session, actor=actor)
|
||||
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
|
||||
# or even better, not refresh at all.
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1901,7 +1903,12 @@ class AgentManager:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
# TODO: This refresh is expensive. If we can find out which fields are needed, we can save cost by only refreshing those fields.
|
||||
# or even better, not refresh at all.
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block management
|
||||
@@ -1989,7 +1996,11 @@ class AgentManager:
|
||||
# TODO: I have too many things right now, so let's look at this later
|
||||
# await session.commit()
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2010,7 +2021,12 @@ class AgentManager:
|
||||
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
|
||||
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agent_encrypted = await agent.to_pydantic_async(decrypt=False)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return (await decrypt_agent_secrets([agent_encrypted]))[0]
|
||||
|
||||
# ======================================================================================================================
|
||||
# Passage Management
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import bounded_gather, enforce_types
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -554,8 +554,13 @@ class ArchiveManager:
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
agents = await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm])
|
||||
return agents
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -19,7 +19,7 @@ from letta.schemas.enums import ActorType, PrimitiveType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import bounded_gather, enforce_types
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -505,8 +505,13 @@ class BlockManager:
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
agents = await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents_orm])
|
||||
return agents
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents_orm]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -24,7 +24,7 @@ from letta.schemas.identity import (
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import bounded_gather, enforce_types
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
@@ -336,7 +336,14 @@ class IdentityManager:
|
||||
ascending=ascending,
|
||||
identity_id=identity.id,
|
||||
)
|
||||
return await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents])
|
||||
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=include, decrypt=False) for agent in agents]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
|
||||
@@ -15,7 +15,7 @@ from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.source import Source as PydanticSource, SourceUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import bounded_gather, enforce_types, printd
|
||||
from letta.utils import bounded_gather, decrypt_agent_secrets, enforce_types, printd
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
@@ -326,7 +326,13 @@ class SourceManager:
|
||||
result = await session.execute(query)
|
||||
agents_orm = result.scalars().all()
|
||||
|
||||
return await bounded_gather([agent.to_pydantic_async(include_relationships=[], include=[]) for agent in agents_orm])
|
||||
# Convert without decrypting to release DB connection before PBKDF2
|
||||
agents_encrypted = await bounded_gather(
|
||||
[agent.to_pydantic_async(include_relationships=[], include=[], decrypt=False) for agent in agents_orm]
|
||||
)
|
||||
|
||||
# Decrypt secrets outside session
|
||||
return await decrypt_agent_secrets(agents_encrypted)
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
|
||||
@@ -1471,3 +1471,38 @@ async def bounded_gather(coros: list[Coroutine], max_concurrency: int = 10) -> l
|
||||
# Sort by original index to preserve order
|
||||
indexed_results.sort(key=lambda x: x[0])
|
||||
return [result for _, result in indexed_results]
|
||||
|
||||
|
||||
async def decrypt_agent_secrets(agents: list, *, max_concurrency: int = 8) -> list:
    """
    Decrypt secrets for all agents outside DB session.

    Intended to be called after the DB session has been exited, so the
    connection is not held while the expensive decryption runs.

    An env var is decrypted only when it has an encrypted value (``value_enc``)
    and its plaintext ``value`` is still ``None`` or empty; env vars that
    already carry a plaintext value are left untouched.

    Args:
        agents: List of PydanticAgentState objects with encrypted secrets.
            Each element must expose ``tool_exec_environment_variables``
            (a list of env vars, or a falsy value when there are none).
        max_concurrency: Upper bound on simultaneous decryptions, forwarded to
            ``bounded_gather``. Defaults to 8, the value previously hard-coded
            here to match the crypto executor size.

    Returns:
        Same list with secrets decrypted (env vars are mutated in place).
    """

    def needs_decryption(env_var) -> bool:
        # Encrypted payload present, but no usable plaintext yet.
        return bool(env_var.value_enc) and (env_var.value is None or env_var.value == "")

    async def decrypt_env_var(env_var):
        env_var.value = await env_var.value_enc.get_plaintext_async()

    # Collect only the env vars that actually need decryption, so no-op
    # coroutines are never created or scheduled.
    decrypt_tasks = []
    for agent in agents:
        env_vars = agent.tool_exec_environment_variables
        if not env_vars:
            continue
        decrypt_tasks.extend(decrypt_env_var(env_var) for env_var in env_vars if needs_decryption(env_var))

    # Decrypt with bounded concurrency; skip the gather entirely when idle.
    if decrypt_tasks:
        await bounded_gather(decrypt_tasks, max_concurrency=max_concurrency)

    return agents
|
||||
|
||||
Reference in New Issue
Block a user