fix: duplicate session commit in step logging (#7512)

* fix: duplicate session commit in step logging

* update all callsites
This commit is contained in:
cthomas
2025-12-18 17:23:09 -08:00
committed by Caren Thomas
parent 4d8d9757aa
commit 9a95a8f976
27 changed files with 148 additions and 74 deletions

View File

@@ -899,7 +899,8 @@ class AgentManager:
agent.message_ids = message_ids
await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
@trace_method
async def list_agents_async(
@@ -1216,7 +1217,8 @@ class AgentManager:
await session.commit()
for agent in agents_to_delete:
await session.delete(agent)
await session.commit()
# context manager now handles commits
# await session.commit()
except Exception as e:
await session.rollback()
logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
@@ -2570,7 +2572,8 @@ class AgentManager:
agent.tool_rules = tool_rules
session.add(agent)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -2643,7 +2646,8 @@ class AgentManager:
else:
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -2767,7 +2771,8 @@ class AgentManager:
else:
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -2804,7 +2809,8 @@ class AgentManager:
else:
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -2832,7 +2838,8 @@ class AgentManager:
agent.tool_rules = tool_rules
session.add(agent)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method

View File

@@ -191,7 +191,8 @@ class ArchiveManager:
is_owner=is_owner,
)
session.add(archives_agents)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -224,7 +225,8 @@ class ArchiveManager:
else:
logger.info(f"Detached agent {agent_id} from archive {archive_id}")
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@@ -609,6 +611,7 @@ class ArchiveManager:
# update the archive with the namespace
await session.execute(update(ArchiveModel).where(ArchiveModel.id == archive_id).values(_vector_db_namespace=namespace_name))
await session.commit()
# context manager now handles commits
# await session.commit()
return namespace_name

View File

@@ -100,7 +100,8 @@ class BlockManager:
block = BlockModel(**data, organization_id=actor.organization_id)
await block.create_async(session, actor=actor, no_commit=True, no_refresh=True)
pydantic_block = block.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_block
@enforce_types
@@ -130,7 +131,8 @@ class BlockManager:
items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
)
result = [m.to_pydantic() for m in created_models]
await session.commit()
# context manager now handles commits
# await session.commit()
return result
@enforce_types
@@ -150,7 +152,8 @@ class BlockManager:
await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
pydantic_block = block.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_block
@enforce_types
@@ -591,7 +594,8 @@ class BlockManager:
new_val = new_val[: block.limit]
block.value = new_val
await session.commit()
# context manager now handles commits
# await session.commit()
if return_hydrated:
# TODO: implement for async
@@ -669,7 +673,8 @@ class BlockManager:
# 7) Flush changes, then commit once
block = await block.update_async(db_session=session, actor=actor, no_commit=True)
await session.commit()
# context manager now handles commits
# await session.commit()
return block.to_pydantic()
@@ -757,7 +762,8 @@ class BlockManager:
block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor)
# 4) Commit
await session.commit()
# context manager now handles commits
# await session.commit()
return block.to_pydantic()
@enforce_types
@@ -805,5 +811,6 @@ class BlockManager:
block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor)
await session.commit()
# context manager now handles commits
# await session.commit()
return block.to_pydantic()

View File

@@ -200,7 +200,8 @@ class FileAgentManager:
stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id))
result = await session.execute(stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
return result.rowcount
@@ -405,7 +406,8 @@ class FileAgentManager:
.values(last_accessed_at=func.now())
)
await session.execute(stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -425,7 +427,8 @@ class FileAgentManager:
.values(last_accessed_at=func.now())
)
await session.execute(stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -458,7 +461,8 @@ class FileAgentManager:
)
closed_file_names = [row.file_name for row in (await session.execute(stmt))]
await session.commit()
# context manager now handles commits
# await session.commit()
return closed_file_names
@enforce_types
@@ -702,7 +706,8 @@ class FileAgentManager:
.values(is_open=False, visible_content=None)
)
await session.commit()
# context manager now handles commits
# await session.commit()
return closed_file_names
async def _get_association_by_file_id(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel:

View File

@@ -246,7 +246,8 @@ class GroupManager:
)
await session.execute(delete_stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
@@ -434,7 +435,8 @@ class GroupManager:
# Add block to group
session.add(GroupsBlocks(group_id=group_id, block_id=block_id))
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
@@ -452,7 +454,8 @@ class GroupManager:
# Remove block from group
delete_group_block = delete(GroupsBlocks).where(and_(GroupsBlocks.group_id == group_id, GroupsBlocks.block_id == block_id))
await session.execute(delete_group_block)
await session.commit()
# context manager now handles commits
# await session.commit()
@staticmethod
def ensure_buffer_length_range_valid(

View File

@@ -257,7 +257,8 @@ class IdentityManager:
if identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
await session.delete(identity)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method

View File

@@ -58,7 +58,8 @@ class JobManager:
job.organization_id = actor.organization_id
job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database
await session.commit()
# context manager now handles commits
# await session.commit()
# Convert to pydantic first, then add agent_id if needed
result = super(JobModel, job).to_pydantic()
@@ -122,7 +123,8 @@ class JobManager:
# Get the updated metadata for callback
final_metadata = job.metadata_
result = job.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
# Dispatch callback outside of database session if needed
if needs_callback:
@@ -143,7 +145,8 @@ class JobManager:
job.callback_error = callback_result.get("callback_error")
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
result = job.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return result
@@ -462,7 +465,8 @@ class JobManager:
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
job.ttft_ns = ttft_ns
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
except Exception as e:
logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
@@ -475,7 +479,8 @@ class JobManager:
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
job.total_duration_ns = total_duration_ns
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
except Exception as e:
logger.warning(f"Failed to record response duration for job {job_id}: {e}")

View File

@@ -45,7 +45,8 @@ class LLMBatchManager:
)
await batch.create_async(session, actor=actor, no_commit=True, no_refresh=True)
pydantic_batch = batch.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_batch
@enforce_types
@@ -98,7 +99,8 @@ class LLMBatchManager:
)
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings))
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method
@@ -285,7 +287,8 @@ class LLMBatchManager:
created_items = await LLMBatchItem.batch_create_async(orm_items, session, actor=actor, no_commit=True, no_refresh=True)
pydantic_items = [item.to_pydantic() for item in created_items]
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_items
@enforce_types
@@ -421,7 +424,8 @@ class LLMBatchManager:
if mappings:
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings))
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method

View File

@@ -257,7 +257,8 @@ class MCPManager:
logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}")
# Commit deletions
await session.commit()
# context manager now handles commits
# await session.commit()
# 2. Update existing tools and add new tools
for tool_name, current_tool in current_tool_map.items():
@@ -394,7 +395,8 @@ class MCPManager:
f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
)
await session.commit()
# context manager now handles commits
# await session.commit()
return mcp_server.to_pydantic()
except Exception as e:
await session.rollback()
@@ -700,7 +702,8 @@ class MCPManager:
)
)
await session.commit()
# context manager now handles commits
# await session.commit()
except NoResultFound:
await session.rollback()
raise ValueError(f"MCP server with id {mcp_server_id} not found.")

View File

@@ -82,7 +82,8 @@ class MCPServerManager:
MCPToolsModel.organization_id == actor.organization_id,
)
)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
async def get_tool_ids_by_mcp_server(self, mcp_server_id: str, actor: PydanticUser) -> List[str]:
@@ -348,7 +349,8 @@ class MCPServerManager:
logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}")
# Commit deletions
await session.commit()
# context manager now handles commits
# await session.commit()
# 2. Update existing tools and add new tools
for tool_name, current_tool in current_tool_map.items():
@@ -489,7 +491,8 @@ class MCPServerManager:
f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
)
await session.commit()
# context manager now handles commits
# await session.commit()
return mcp_server.to_pydantic()
except Exception as e:
await session.rollback()
@@ -871,7 +874,8 @@ class MCPServerManager:
)
)
await session.commit()
# context manager now handles commits
# await session.commit()
except NoResultFound:
await session.rollback()
raise ValueError(f"MCP server with id {mcp_server_id} not found.")

View File

@@ -515,7 +515,8 @@ class MessageManager:
async with db_registry.async_session() as session:
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
result = [msg.to_pydantic() for msg in created_messages]
await session.commit()
# context manager now handles commits
# await session.commit()
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
@@ -673,7 +674,8 @@ class MessageManager:
message = self._update_message_by_id_impl(message_id, message_update, actor, message)
await message.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
pydantic_message = message.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
@@ -979,7 +981,8 @@ class MessageManager:
rowcount = result.rowcount
# 4) commit once
await session.commit()
# context manager now handles commits
# await session.commit()
# 5) delete from turbopuffer if enabled (outside of DB session)
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
@@ -1031,7 +1034,8 @@ class MessageManager:
rowcount = result.rowcount
# commit once
await session.commit()
# context manager now handles commits
# await session.commit()
if should_use_tpuf_for_messages() and agent_ids:
try:

View File

@@ -174,7 +174,8 @@ class ProviderManager:
# Soft delete in provider table
await existing_provider.delete_async(session, actor=actor)
await session.commit()
# context manager now handles commits
# await session.commit()
@enforce_types
@trace_method

View File

@@ -88,7 +88,8 @@ class RunManager:
num_steps=0, # Initialize to 0
)
await metrics.create_async(session)
await session.commit()
# context manager now handles commits
# await session.commit()
return run.to_pydantic()
@@ -366,7 +367,8 @@ class RunManager:
final_metadata = run.metadata_
pydantic_run = run.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
# Update agent's last_stop_reason when run completes
# Do this after run update is committed to database
@@ -417,7 +419,8 @@ class RunManager:
metrics.num_steps = num_steps
metrics.tools_used = list(tools_used) if tools_used else None
await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
# Dispatch callback outside of database session if needed
if needs_callback:
@@ -445,7 +448,8 @@ class RunManager:
run.callback_error = callback_result.get("callback_error")
pydantic_run = run.to_pydantic()
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_run

View File

@@ -170,7 +170,8 @@ class SourceManager:
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
await session.execute(upsert_stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
# fetch results
source_names = [source.name for source in source_data_list]

View File

@@ -206,7 +206,8 @@ class StepManager:
new_step = StepModel(**step_data)
await new_step.create_async(session, no_commit=True, no_refresh=True)
pydantic_step = new_step.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_step
@enforce_types
@@ -266,7 +267,8 @@ class StepManager:
raise Exception("Unauthorized")
step.tid = transaction_id
await session.commit()
# context manager now handles commits
# await session.commit()
return step.to_pydantic()
@enforce_types
@@ -318,7 +320,8 @@ class StepManager:
raise Exception("Unauthorized")
step.stop_reason = stop_reason
await session.commit()
# context manager now handles commits
# await session.commit()
return step
@enforce_types
@@ -364,7 +367,8 @@ class StepManager:
if stop_reason:
step.stop_reason = stop_reason.stop_reason
await session.commit()
# context manager now handles commits
# await session.commit()
pydantic_step = step.to_pydantic()
# Send webhook notification for step completion outside the DB session
webhook_service = WebhookService()
@@ -415,7 +419,8 @@ class StepManager:
if usage.completion_tokens_details:
step.completion_tokens_details = usage.completion_tokens_details.model_dump()
await session.commit()
# context manager now handles commits
# await session.commit()
pydantic_step = step.to_pydantic()
# Send webhook notification for step completion outside the DB session
webhook_service = WebhookService()
@@ -455,7 +460,8 @@ class StepManager:
if stop_reason:
step.stop_reason = stop_reason.stop_reason
await session.commit()
# context manager now handles commits
# await session.commit()
pydantic_step = step.to_pydantic()
# Send webhook notification for step completion outside the DB session
webhook_service = WebhookService()

View File

@@ -36,7 +36,8 @@ class TelemetryManager:
provider_trace.response_json = json_loads(response_json_str)
await provider_trace.create_async(session, actor=actor, no_commit=True, no_refresh=True)
pydantic_provider_trace = provider_trace.to_pydantic()
await session.commit()
# context manager now handles commits
# await session.commit()
return pydantic_provider_trace

View File

@@ -1226,7 +1226,8 @@ class ToolManager:
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
await session.execute(upsert_stmt)
await session.commit()
# context manager now handles commits
# await session.commit()
# fetch results (includes both inserted and skipped tools)
tool_names = [tool.name for tool in tool_data_list]

View File

@@ -120,7 +120,8 @@ async def clear_tables():
async with db_registry.async_session() as session:
await session.execute(delete(SandboxEnvironmentVariable))
await session.execute(delete(SandboxConfig))
await session.commit() # Commit the deletion
# context manager now handles commits
# await session.commit()
@pytest.fixture

View File

@@ -43,7 +43,8 @@ async def _clear_tables():
if table.name == "block_history":
continue
await session.execute(table.delete()) # Truncate table
await session.commit()
# context manager now handles commits
# await session.commit()
def _run_server():

View File

@@ -85,7 +85,8 @@ async def clear_tables():
async with db_registry.async_session() as session:
await session.execute(delete(SandboxEnvironmentVariable))
await session.execute(delete(SandboxConfig))
await session.commit() # Commit the deletion
# context manager now handles commits
# await session.commit()
@pytest.fixture

View File

@@ -53,7 +53,8 @@ async def _clear_tables():
async with db_registry.async_session() as session:
for table in reversed(Base.metadata.sorted_tables):
await session.execute(table.delete())
await session.commit()
# context manager now handles commits
# await session.commit()
@pytest.fixture(autouse=True)

View File

@@ -60,7 +60,8 @@ async def clear_tables():
async with db_registry.async_session() as session:
await session.execute(delete(SandboxEnvironmentVariable))
await session.execute(delete(SandboxConfig))
await session.commit() # Commit the deletion
# context manager now handles commits
# await session.commit()
@pytest.fixture

View File

@@ -2423,7 +2423,8 @@ async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user,
)
session.add(corrupted_tool)
await session.commit()
# context manager now handles commits
# await session.commit()
corrupted_tool_id = corrupted_tool.id
# Now try to list tools - it should still work and not include the corrupted tool

View File

@@ -79,7 +79,8 @@ def _clear_tables():
async with db_registry.async_session() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
await session.execute(table.delete()) # Truncate table
await session.commit()
# context manager now handles commits
# await session.commit()
asyncio.run(_clear())

View File

@@ -43,7 +43,8 @@ async def _clear_tables():
async with db_registry.async_session() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
await session.execute(table.delete()) # Truncate table
await session.commit()
# context manager now handles commits
# await session.commit()
@pytest.fixture(autouse=True)

View File

@@ -218,7 +218,8 @@ async def clear_tables():
async with db_registry.async_session() as session:
await session.execute(delete(SandboxEnvironmentVariable))
await session.execute(delete(SandboxConfig))
await session.commit()
# context manager now handles commits
# await session.commit()
# --------------------------------------------------------------------------------------------------------------------

View File

@@ -167,7 +167,8 @@ class TestMCPServerEncryption:
updated_at=datetime.now(timezone.utc),
)
session.add(db_server)
await session.commit()
# context manager now handles commits
# await session.commit()
# Retrieve server directly by ID to avoid issues with other servers in DB
test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
@@ -183,7 +184,8 @@ class TestMCPServerEncryption:
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id))
db_server = result.scalar_one()
await session.delete(db_server)
await session.commit()
# context manager now handles commits
# await session.commit()
finally:
# Restore original encryption key
@@ -338,7 +340,8 @@ class TestMCPOAuthEncryption:
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
await session.commit()
# context manager now handles commits
# await session.commit()
# Retrieve through manager by ID
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
@@ -456,7 +459,8 @@ class TestMCPOAuthEncryption:
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
await session.commit()
# context manager now handles commits
# await session.commit()
# Retrieve through manager
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
@@ -505,7 +509,8 @@ class TestMCPOAuthEncryption:
updated_at=datetime.now(timezone.utc),
)
session.add(db_oauth)
await session.commit()
# context manager now handles commits
# await session.commit()
# Retrieve through manager
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)