From 9a95a8f9767158ba09dae4360d8ae93312bc8e0e Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 18 Dec 2025 17:23:09 -0800 Subject: [PATCH] fix: duplicate session commit in step logging (#7512) * fix: duplicate session commit in step logging * update all callsites --- letta/services/agent_manager.py | 21 ++++++++++++------- letta/services/archive_manager.py | 9 +++++--- letta/services/block_manager.py | 21 ++++++++++++------- letta/services/files_agents_manager.py | 15 ++++++++----- letta/services/group_manager.py | 9 +++++--- letta/services/identity_manager.py | 3 ++- letta/services/job_manager.py | 15 ++++++++----- letta/services/llm_batch_manager.py | 12 +++++++---- letta/services/mcp_manager.py | 9 +++++--- letta/services/mcp_server_manager.py | 12 +++++++---- letta/services/message_manager.py | 12 +++++++---- letta/services/provider_manager.py | 3 ++- letta/services/run_manager.py | 12 +++++++---- letta/services/source_manager.py | 3 ++- letta/services/step_manager.py | 18 ++++++++++------ letta/services/telemetry_manager.py | 3 ++- letta/services/tool_manager.py | 3 ++- tests/integration_test_async_tool_sandbox.py | 3 ++- tests/integration_test_batch_api_cron_jobs.py | 3 ++- tests/integration_test_modal.py | 3 ++- tests/integration_test_token_counters.py | 3 ++- ...integration_test_tool_execution_sandbox.py | 3 ++- tests/managers/test_tool_manager.py | 3 ++- tests/test_agent_serialization.py | 3 ++- tests/test_agent_serialization_v2.py | 3 ++- tests/test_client.py | 3 ++- tests/test_mcp_encryption.py | 15 ++++++++----- 27 files changed, 148 insertions(+), 74 deletions(-) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d0dd21a0..16648779 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -899,7 +899,8 @@ class AgentManager: agent.message_ids = message_ids await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) - await session.commit() + # context manager now handles commits + # await session.commit() @trace_method async def list_agents_async( @@ -1216,7 +1217,8 @@ class AgentManager: await session.commit() for agent in agents_to_delete: await session.delete(agent) - await session.commit() + # context manager now handles commits + # await session.commit() except Exception as e: await session.rollback() logger.exception(f"Failed to hard delete Agent with ID {agent_id}") @@ -2570,7 +2572,8 @@ class AgentManager: agent.tool_rules = tool_rules session.add(agent) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -2643,7 +2646,8 @@ class AgentManager: else: logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}") - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method @@ -2767,7 +2771,8 @@ class AgentManager: else: logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}") - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -2804,7 +2809,8 @@ class AgentManager: else: logger.info(f"Detached all {detached_count} tools from agent {agent_id}") - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -2832,7 +2838,8 @@ class AgentManager: agent.tool_rules = tool_rules session.add(agent) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 79c09c8e..6e1bf0ec 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -191,7 +191,8 @@ class ArchiveManager: is_owner=is_owner, ) session.add(archives_agents) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -224,7 +225,8 @@ class ArchiveManager: else: logger.info(f"Detached agent {agent_id} from archive {archive_id}") - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) @@ -609,6 +611,7 @@ class ArchiveManager: # update the archive with the namespace await session.execute(update(ArchiveModel).where(ArchiveModel.id == archive_id).values(_vector_db_namespace=namespace_name)) - await session.commit() + # context manager now handles commits + # await session.commit() return namespace_name diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 1ed2392f..4afff931 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -100,7 +100,8 @@ class BlockManager: block = BlockModel(**data, organization_id=actor.organization_id) await block.create_async(session, actor=actor, no_commit=True, no_refresh=True) pydantic_block = block.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_block @enforce_types @@ -130,7 +131,8 @@ class BlockManager: items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True ) result = [m.to_pydantic() for m in created_models] - await session.commit() + # context manager now handles commits + # await session.commit() return result @enforce_types @@ -150,7 +152,8 @@ class BlockManager: await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) pydantic_block = block.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_block @enforce_types @@ -591,7 +594,8 @@ class BlockManager: new_val = new_val[: block.limit] block.value = new_val - await session.commit() + # context manager now handles commits + # await session.commit() if return_hydrated: # TODO: implement for async @@ -669,7 +673,8 @@ class BlockManager: # 7) Flush changes, then commit once block = await block.update_async(db_session=session, actor=actor, no_commit=True) - await session.commit() + # context manager now handles commits + # await session.commit() return block.to_pydantic() @@ -757,7 +762,8 @@ class BlockManager: block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor) # 4) Commit - await session.commit() + # context manager now handles commits + # await session.commit() return block.to_pydantic() @enforce_types @@ -805,5 +811,6 @@ class BlockManager: block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor) - await session.commit() + # context manager now handles commits + # await session.commit() return block.to_pydantic() diff --git a/letta/services/files_agents_manager.py b/letta/services/files_agents_manager.py index f993f89b..13013f70 100644 --- a/letta/services/files_agents_manager.py +++ b/letta/services/files_agents_manager.py @@ -200,7 +200,8 @@ class FileAgentManager: stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id)) result = await session.execute(stmt) - await session.commit() + # context manager now handles commits + # await session.commit() return result.rowcount @@ -405,7 +406,8 @@ class FileAgentManager: .values(last_accessed_at=func.now()) ) await session.execute(stmt) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method @@ -425,7 +427,8 @@ class FileAgentManager: .values(last_accessed_at=func.now()) ) await session.execute(stmt) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method @@ -458,7 +461,8 @@ class FileAgentManager: ) closed_file_names = [row.file_name for row in (await session.execute(stmt))] - await session.commit() + # context manager now handles commits + # await session.commit() return closed_file_names @enforce_types @@ -702,7 +706,8 @@ class FileAgentManager: .values(is_open=False, visible_content=None) ) - await session.commit() + # context manager now handles commits + # await session.commit() return closed_file_names async def _get_association_by_file_id(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel: diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 6074302f..aaf7479e 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -246,7 +246,8 @@ class GroupManager: ) await session.execute(delete_stmt) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) @@ -434,7 +435,8 @@ class GroupManager: # Add block to group session.add(GroupsBlocks(group_id=group_id, block_id=block_id)) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) @@ -452,7 +454,8 @@ class GroupManager: # Remove block from group delete_group_block = delete(GroupsBlocks).where(and_(GroupsBlocks.group_id == group_id, GroupsBlocks.block_id == block_id)) await session.execute(delete_group_block) - await session.commit() + # context manager now handles commits + # await session.commit() @staticmethod def ensure_buffer_length_range_valid( diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 91adebea..40545e9f 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -257,7 +257,8 @@ class IdentityManager: if identity.organization_id != actor.organization_id: raise HTTPException(status_code=403, detail="Forbidden") await session.delete(identity) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index ddabc80f..0b39ece1 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -58,7 +58,8 @@ class JobManager: job.organization_id = actor.organization_id job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database - await session.commit() + # context manager now handles commits + # await session.commit() # Convert to pydantic first, then add agent_id if needed result = super(JobModel, job).to_pydantic() @@ -122,7 +123,8 @@ class JobManager: # Get the updated metadata for callback final_metadata = job.metadata_ result = job.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() # Dispatch callback outside of database session if needed if needs_callback: @@ -143,7 +145,8 @@ class JobManager: job.callback_error = callback_result.get("callback_error") await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) result = job.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() return result @@ -462,7 +465,8 @@ class JobManager: job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) job.ttft_ns = ttft_ns await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) - await session.commit() + # context manager now handles commits + # await session.commit() except Exception as e: logger.warning(f"Failed to record TTFT for job {job_id}: {e}") @@ -475,7 +479,8 @@ class JobManager: job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) job.total_duration_ns = total_duration_ns await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) - await session.commit() + # context manager now handles commits + # await session.commit() except Exception as e: logger.warning(f"Failed to record response duration for job {job_id}: {e}") diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 50f560da..6c09c2be 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -45,7 +45,8 @@ class LLMBatchManager: ) await batch.create_async(session, actor=actor, no_commit=True, no_refresh=True) pydantic_batch = batch.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_batch @enforce_types @@ -98,7 +99,8 @@ class LLMBatchManager: ) await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings)) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method @@ -285,7 +287,8 @@ class LLMBatchManager: created_items = await LLMBatchItem.batch_create_async(orm_items, session, actor=actor, no_commit=True, no_refresh=True) pydantic_items = [item.to_pydantic() for item in created_items] - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_items @enforce_types @@ -421,7 +424,8 @@ class LLMBatchManager: if mappings: await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings)) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index e0d1c838..cfcf1e33 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -257,7 +257,8 @@ class MCPManager: logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}") # Commit deletions - await session.commit() + # context manager now handles commits + # await session.commit() # 2. Update existing tools and add new tools for tool_name, current_tool in current_tool_map.items(): @@ -394,7 +395,8 @@ class MCPManager: f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}" ) - await session.commit() + # context manager now handles commits + # await session.commit() return mcp_server.to_pydantic() except Exception as e: await session.rollback() @@ -700,7 +702,8 @@ class MCPManager: ) ) - await session.commit() + # context manager now handles commits + # await session.commit() except NoResultFound: await session.rollback() raise ValueError(f"MCP server with id {mcp_server_id} not found.") diff --git a/letta/services/mcp_server_manager.py b/letta/services/mcp_server_manager.py index bea9f854..a81c1ec0 100644 --- a/letta/services/mcp_server_manager.py +++ b/letta/services/mcp_server_manager.py @@ -82,7 +82,8 @@ class MCPServerManager: MCPToolsModel.organization_id == actor.organization_id, ) ) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types async def get_tool_ids_by_mcp_server(self, mcp_server_id: str, actor: PydanticUser) -> List[str]: @@ -348,7 +349,8 @@ class MCPServerManager: logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}") # Commit deletions - await session.commit() + # context manager now handles commits + # await session.commit() # 2. Update existing tools and add new tools for tool_name, current_tool in current_tool_map.items(): @@ -489,7 +491,8 @@ class MCPServerManager: f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}" ) - await session.commit() + # context manager now handles commits + # await session.commit() return mcp_server.to_pydantic() except Exception as e: await session.rollback() @@ -871,7 +874,8 @@ class MCPServerManager: ) ) - await session.commit() + # context manager now handles commits + # await session.commit() except NoResultFound: await session.rollback() raise ValueError(f"MCP server with id {mcp_server_id} not found.") diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index f0ee955f..6f45f2ad 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -515,7 +515,8 @@ class MessageManager: async with db_registry.async_session() as session: created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True) result = [msg.to_pydantic() for msg in created_messages] - await session.commit() + # context manager now handles commits + # await session.commit() from letta.helpers.tpuf_client import should_use_tpuf_for_messages @@ -673,7 +674,8 @@ class MessageManager: message = self._update_message_by_id_impl(message_id, message_update, actor, message) await message.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) pydantic_message = message.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() from letta.helpers.tpuf_client import should_use_tpuf_for_messages @@ -979,7 +981,8 @@ class MessageManager: rowcount = result.rowcount # 4) commit once - await session.commit() + # context manager now handles commits + # await session.commit() # 5) delete from turbopuffer if enabled (outside of DB session) from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages @@ -1031,7 +1034,8 @@ class MessageManager: rowcount = result.rowcount # commit once - await session.commit() + # context manager now handles commits + # await session.commit() if should_use_tpuf_for_messages() and agent_ids: try: diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 8fbaf043..7b1d685c 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -174,7 +174,8 @@ class ProviderManager: # Soft delete in provider table await existing_provider.delete_async(session, actor=actor) - await session.commit() + # context manager now handles commits + # await session.commit() @enforce_types @trace_method diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index 6140d24c..c05b8da3 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -88,7 +88,8 @@ class RunManager: num_steps=0, # Initialize to 0 ) await metrics.create_async(session) - await session.commit() + # context manager now handles commits + # await session.commit() return run.to_pydantic() @@ -366,7 +367,8 @@ class RunManager: final_metadata = run.metadata_ pydantic_run = run.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() # Update agent's last_stop_reason when run completes # Do this after run update is committed to database @@ -417,7 +419,8 @@ class RunManager: metrics.num_steps = num_steps metrics.tools_used = list(tools_used) if tools_used else None await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) - await session.commit() + # context manager now handles commits + # await session.commit() # Dispatch callback outside of database session if needed if needs_callback: @@ -445,7 +448,8 @@ class RunManager: run.callback_error = callback_result.get("callback_error") pydantic_run = run.to_pydantic() await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_run diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 9099e822..af050cf6 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -170,7 +170,8 @@ class SourceManager: upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict) await session.execute(upsert_stmt) - await session.commit() + # context manager now handles commits + # await session.commit() # fetch results source_names = [source.name for source in source_data_list] diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 64a1e4d3..92733e75 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -206,7 +206,8 @@ class StepManager: new_step = StepModel(**step_data) await new_step.create_async(session, no_commit=True, no_refresh=True) pydantic_step = new_step.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_step @enforce_types @@ -266,7 +267,8 @@ class StepManager: raise Exception("Unauthorized") step.tid = transaction_id - await session.commit() + # context manager now handles commits + # await session.commit() return step.to_pydantic() @enforce_types @@ -318,7 +320,8 @@ class StepManager: raise Exception("Unauthorized") step.stop_reason = stop_reason - await session.commit() + # context manager now handles commits + # await session.commit() return step @enforce_types @@ -364,7 +367,8 @@ class StepManager: if stop_reason: step.stop_reason = stop_reason.stop_reason - await session.commit() + # context manager now handles commits + # await session.commit() pydantic_step = step.to_pydantic() # Send webhook notification for step completion outside the DB session webhook_service = WebhookService() @@ -415,7 +419,8 @@ class StepManager: if usage.completion_tokens_details: step.completion_tokens_details = usage.completion_tokens_details.model_dump() - await session.commit() + # context manager now handles commits + # await session.commit() pydantic_step = step.to_pydantic() # Send webhook notification for step completion outside the DB session webhook_service = WebhookService() @@ -455,7 +460,8 @@ class StepManager: if stop_reason: step.stop_reason = stop_reason.stop_reason - await session.commit() + # context manager now handles commits + # await session.commit() pydantic_step = step.to_pydantic() # Send webhook notification for step completion outside the DB session webhook_service = WebhookService() diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py index c74a9cd6..e7b7b5df 100644 --- a/letta/services/telemetry_manager.py +++ b/letta/services/telemetry_manager.py @@ -36,7 +36,8 @@ class TelemetryManager: provider_trace.response_json = json_loads(response_json_str) await provider_trace.create_async(session, actor=actor, no_commit=True, no_refresh=True) pydantic_provider_trace = provider_trace.to_pydantic() - await session.commit() + # context manager now handles commits + # await session.commit() return pydantic_provider_trace diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index de3ff428..d5f3af6d 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -1226,7 +1226,8 @@ class ToolManager: upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"]) await session.execute(upsert_stmt) - await session.commit() + # context manager now handles commits + # await session.commit() # fetch results (includes both inserted and skipped tools) tool_names = [tool.name for tool in tool_data_list] diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index a340b93d..5b6b872e 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -120,7 +120,8 @@ async def clear_tables(): async with db_registry.async_session() as session: await session.execute(delete(SandboxEnvironmentVariable)) await session.execute(delete(SandboxConfig)) - await session.commit() # Commit the deletion + # context manager now handles commits + # await session.commit() @pytest.fixture diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index bae4f857..ab072b7a 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -43,7 +43,8 @@ async def _clear_tables(): if table.name == "block_history": continue await session.execute(table.delete()) # Truncate table - await session.commit() + # context manager now handles commits + # await session.commit() def _run_server(): diff --git a/tests/integration_test_modal.py b/tests/integration_test_modal.py index 1689097d..8a2bb3cb 100644 --- a/tests/integration_test_modal.py +++ b/tests/integration_test_modal.py @@ -85,7 +85,8 @@ async def clear_tables(): async with db_registry.async_session() as session: await session.execute(delete(SandboxEnvironmentVariable)) await session.execute(delete(SandboxConfig)) - await session.commit() # Commit the deletion + # context manager now handles commits + # await session.commit() @pytest.fixture diff --git a/tests/integration_test_token_counters.py b/tests/integration_test_token_counters.py index cb05edbf..21dbb17e 100644 --- a/tests/integration_test_token_counters.py +++ b/tests/integration_test_token_counters.py @@ -53,7 +53,8 @@ async def _clear_tables(): async with db_registry.async_session() as session: for table in reversed(Base.metadata.sorted_tables): await session.execute(table.delete()) - await session.commit() + # context manager now handles commits + # await session.commit() @pytest.fixture(autouse=True) diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 928fc011..c81ae5d3 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -60,7 +60,8 @@ async def clear_tables(): async with db_registry.async_session() as session: await session.execute(delete(SandboxEnvironmentVariable)) await session.execute(delete(SandboxConfig)) - await session.commit() # Commit the deletion + # context manager now handles commits + # await session.commit() @pytest.fixture diff --git a/tests/managers/test_tool_manager.py b/tests/managers/test_tool_manager.py index 9861029e..4658acb8 100644 --- a/tests/managers/test_tool_manager.py +++ b/tests/managers/test_tool_manager.py @@ -2423,7 +2423,8 @@ async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user, ) session.add(corrupted_tool) - await session.commit() + # context manager now handles commits + # await session.commit() corrupted_tool_id = corrupted_tool.id # Now try to list tools - it should still work and not include the corrupted tool diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 7957e4ef..12235b77 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -79,7 +79,8 @@ def _clear_tables(): async with db_registry.async_session() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues await session.execute(table.delete()) # Truncate table - await session.commit() + # context manager now handles commits + # await session.commit() asyncio.run(_clear()) diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 60692c97..26b2d966 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -43,7 +43,8 @@ async def _clear_tables(): async with db_registry.async_session() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues await session.execute(table.delete()) # Truncate table - await session.commit() + # context manager now handles commits + # await session.commit() @pytest.fixture(autouse=True) diff --git a/tests/test_client.py b/tests/test_client.py index 386620cf..96134e35 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -218,7 +218,8 @@ async def clear_tables(): async with db_registry.async_session() as session: await session.execute(delete(SandboxEnvironmentVariable)) await session.execute(delete(SandboxConfig)) - await session.commit() + # context manager now handles commits + # await session.commit() # -------------------------------------------------------------------------------------------------------------------- diff --git a/tests/test_mcp_encryption.py b/tests/test_mcp_encryption.py index 1e5c046f..e8c9c0f1 100644 --- a/tests/test_mcp_encryption.py +++ b/tests/test_mcp_encryption.py @@ -167,7 +167,8 @@ class TestMCPServerEncryption: updated_at=datetime.now(timezone.utc), ) session.add(db_server) - await session.commit() + # context manager now handles commits + # await session.commit() # Retrieve server directly by ID to avoid issues with other servers in DB test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user) @@ -183,7 +184,8 @@ class TestMCPServerEncryption: result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id)) db_server = result.scalar_one() await session.delete(db_server) - await session.commit() + # context manager now handles commits + # await session.commit() finally: # Restore original encryption key @@ -338,7 +340,8 @@ class TestMCPOAuthEncryption: updated_at=datetime.now(timezone.utc), ) session.add(db_oauth) - await session.commit() + # context manager now handles commits + # await session.commit() # Retrieve through manager by ID test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) @@ -456,7 +459,8 @@ class TestMCPOAuthEncryption: updated_at=datetime.now(timezone.utc), ) session.add(db_oauth) - await session.commit() + # context manager now handles commits + # await session.commit() # Retrieve through manager test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user) @@ -505,7 +509,8 @@ class TestMCPOAuthEncryption: updated_at=datetime.now(timezone.utc), ) session.add(db_oauth) - await session.commit() + # context manager now handles commits + # await session.commit() # Retrieve through manager test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)