fix: duplicate session commit in step logging (#7512)
* fix: duplicate session commit in step logging * update all callsites
This commit is contained in:
@@ -899,7 +899,8 @@ class AgentManager:
|
||||
agent.message_ids = message_ids
|
||||
|
||||
await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@trace_method
|
||||
async def list_agents_async(
|
||||
@@ -1216,7 +1217,8 @@ class AgentManager:
|
||||
await session.commit()
|
||||
for agent in agents_to_delete:
|
||||
await session.delete(agent)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
|
||||
@@ -2570,7 +2572,8 @@ class AgentManager:
|
||||
agent.tool_rules = tool_rules
|
||||
session.add(agent)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -2643,7 +2646,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2767,7 +2771,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -2804,7 +2809,8 @@ class AgentManager:
|
||||
else:
|
||||
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -2832,7 +2838,8 @@ class AgentManager:
|
||||
|
||||
agent.tool_rules = tool_rules
|
||||
session.add(agent)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -191,7 +191,8 @@ class ArchiveManager:
|
||||
is_owner=is_owner,
|
||||
)
|
||||
session.add(archives_agents)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -224,7 +225,8 @@ class ArchiveManager:
|
||||
else:
|
||||
logger.info(f"Detached agent {agent_id} from archive {archive_id}")
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@@ -609,6 +611,7 @@ class ArchiveManager:
|
||||
|
||||
# update the archive with the namespace
|
||||
await session.execute(update(ArchiveModel).where(ArchiveModel.id == archive_id).values(_vector_db_namespace=namespace_name))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return namespace_name
|
||||
|
||||
@@ -100,7 +100,8 @@ class BlockManager:
|
||||
block = BlockModel(**data, organization_id=actor.organization_id)
|
||||
await block.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_block = block.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_block
|
||||
|
||||
@enforce_types
|
||||
@@ -130,7 +131,8 @@ class BlockManager:
|
||||
items=block_models, db_session=session, actor=actor, no_commit=True, no_refresh=True
|
||||
)
|
||||
result = [m.to_pydantic() for m in created_models]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -150,7 +152,8 @@ class BlockManager:
|
||||
|
||||
await block.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_block = block.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_block
|
||||
|
||||
@enforce_types
|
||||
@@ -591,7 +594,8 @@ class BlockManager:
|
||||
new_val = new_val[: block.limit]
|
||||
block.value = new_val
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
if return_hydrated:
|
||||
# TODO: implement for async
|
||||
@@ -669,7 +673,8 @@ class BlockManager:
|
||||
|
||||
# 7) Flush changes, then commit once
|
||||
block = await block.update_async(db_session=session, actor=actor, no_commit=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return block.to_pydantic()
|
||||
|
||||
@@ -757,7 +762,8 @@ class BlockManager:
|
||||
block = await self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor)
|
||||
|
||||
# 4) Commit
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -805,5 +811,6 @@ class BlockManager:
|
||||
|
||||
block = await self._move_block_to_sequence(session, block, next_entry.sequence_number, actor)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return block.to_pydantic()
|
||||
|
||||
@@ -200,7 +200,8 @@ class FileAgentManager:
|
||||
stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id))
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return result.rowcount
|
||||
|
||||
@@ -405,7 +406,8 @@ class FileAgentManager:
|
||||
.values(last_accessed_at=func.now())
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -425,7 +427,8 @@ class FileAgentManager:
|
||||
.values(last_accessed_at=func.now())
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -458,7 +461,8 @@ class FileAgentManager:
|
||||
)
|
||||
|
||||
closed_file_names = [row.file_name for row in (await session.execute(stmt))]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return closed_file_names
|
||||
|
||||
@enforce_types
|
||||
@@ -702,7 +706,8 @@ class FileAgentManager:
|
||||
.values(is_open=False, visible_content=None)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return closed_file_names
|
||||
|
||||
async def _get_association_by_file_id(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel:
|
||||
|
||||
@@ -246,7 +246,8 @@ class GroupManager:
|
||||
)
|
||||
await session.execute(delete_stmt)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@@ -434,7 +435,8 @@ class GroupManager:
|
||||
|
||||
# Add block to group
|
||||
session.add(GroupsBlocks(group_id=group_id, block_id=block_id))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@@ -452,7 +454,8 @@ class GroupManager:
|
||||
# Remove block from group
|
||||
delete_group_block = delete(GroupsBlocks).where(and_(GroupsBlocks.group_id == group_id, GroupsBlocks.block_id == block_id))
|
||||
await session.execute(delete_group_block)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@staticmethod
|
||||
def ensure_buffer_length_range_valid(
|
||||
|
||||
@@ -257,7 +257,8 @@ class IdentityManager:
|
||||
if identity.organization_id != actor.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
await session.delete(identity)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -58,7 +58,8 @@ class JobManager:
|
||||
job.organization_id = actor.organization_id
|
||||
job = await job.create_async(session, actor=actor, no_commit=True, no_refresh=True) # Save job in the database
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Convert to pydantic first, then add agent_id if needed
|
||||
result = super(JobModel, job).to_pydantic()
|
||||
@@ -122,7 +123,8 @@ class JobManager:
|
||||
# Get the updated metadata for callback
|
||||
final_metadata = job.metadata_
|
||||
result = job.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Dispatch callback outside of database session if needed
|
||||
if needs_callback:
|
||||
@@ -143,7 +145,8 @@ class JobManager:
|
||||
job.callback_error = callback_result.get("callback_error")
|
||||
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
result = job.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return result
|
||||
|
||||
@@ -462,7 +465,8 @@ class JobManager:
|
||||
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
job.ttft_ns = ttft_ns
|
||||
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
|
||||
|
||||
@@ -475,7 +479,8 @@ class JobManager:
|
||||
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
job.total_duration_ns = total_duration_ns
|
||||
await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record response duration for job {job_id}: {e}")
|
||||
|
||||
|
||||
@@ -45,7 +45,8 @@ class LLMBatchManager:
|
||||
)
|
||||
await batch.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_batch = batch.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_batch
|
||||
|
||||
@enforce_types
|
||||
@@ -98,7 +99,8 @@ class LLMBatchManager:
|
||||
)
|
||||
|
||||
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -285,7 +287,8 @@ class LLMBatchManager:
|
||||
created_items = await LLMBatchItem.batch_create_async(orm_items, session, actor=actor, no_commit=True, no_refresh=True)
|
||||
|
||||
pydantic_items = [item.to_pydantic() for item in created_items]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_items
|
||||
|
||||
@enforce_types
|
||||
@@ -421,7 +424,8 @@ class LLMBatchManager:
|
||||
|
||||
if mappings:
|
||||
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -257,7 +257,8 @@ class MCPManager:
|
||||
logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}")
|
||||
|
||||
# Commit deletions
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# 2. Update existing tools and add new tools
|
||||
for tool_name, current_tool in current_tool_map.items():
|
||||
@@ -394,7 +395,8 @@ class MCPManager:
|
||||
f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return mcp_server.to_pydantic()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
@@ -700,7 +702,8 @@ class MCPManager:
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except NoResultFound:
|
||||
await session.rollback()
|
||||
raise ValueError(f"MCP server with id {mcp_server_id} not found.")
|
||||
|
||||
@@ -82,7 +82,8 @@ class MCPServerManager:
|
||||
MCPToolsModel.organization_id == actor.organization_id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
async def get_tool_ids_by_mcp_server(self, mcp_server_id: str, actor: PydanticUser) -> List[str]:
|
||||
@@ -348,7 +349,8 @@ class MCPServerManager:
|
||||
logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}")
|
||||
|
||||
# Commit deletions
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# 2. Update existing tools and add new tools
|
||||
for tool_name, current_tool in current_tool_map.items():
|
||||
@@ -489,7 +491,8 @@ class MCPServerManager:
|
||||
f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return mcp_server.to_pydantic()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
@@ -871,7 +874,8 @@ class MCPServerManager:
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
except NoResultFound:
|
||||
await session.rollback()
|
||||
raise ValueError(f"MCP server with id {mcp_server_id} not found.")
|
||||
|
||||
@@ -515,7 +515,8 @@ class MessageManager:
|
||||
async with db_registry.async_session() as session:
|
||||
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
|
||||
result = [msg.to_pydantic() for msg in created_messages]
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
|
||||
|
||||
@@ -673,7 +674,8 @@ class MessageManager:
|
||||
message = self._update_message_by_id_impl(message_id, message_update, actor, message)
|
||||
await message.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_message = message.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
|
||||
|
||||
@@ -979,7 +981,8 @@ class MessageManager:
|
||||
rowcount = result.rowcount
|
||||
|
||||
# 4) commit once
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# 5) delete from turbopuffer if enabled (outside of DB session)
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
@@ -1031,7 +1034,8 @@ class MessageManager:
|
||||
rowcount = result.rowcount
|
||||
|
||||
# commit once
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
if should_use_tpuf_for_messages() and agent_ids:
|
||||
try:
|
||||
|
||||
@@ -174,7 +174,8 @@ class ProviderManager:
|
||||
# Soft delete in provider table
|
||||
await existing_provider.delete_async(session, actor=actor)
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -88,7 +88,8 @@ class RunManager:
|
||||
num_steps=0, # Initialize to 0
|
||||
)
|
||||
await metrics.create_async(session)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return run.to_pydantic()
|
||||
|
||||
@@ -366,7 +367,8 @@ class RunManager:
|
||||
final_metadata = run.metadata_
|
||||
pydantic_run = run.to_pydantic()
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Update agent's last_stop_reason when run completes
|
||||
# Do this after run update is committed to database
|
||||
@@ -417,7 +419,8 @@ class RunManager:
|
||||
metrics.num_steps = num_steps
|
||||
metrics.tools_used = list(tools_used) if tools_used else None
|
||||
await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Dispatch callback outside of database session if needed
|
||||
if needs_callback:
|
||||
@@ -445,7 +448,8 @@ class RunManager:
|
||||
run.callback_error = callback_result.get("callback_error")
|
||||
pydantic_run = run.to_pydantic()
|
||||
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
return pydantic_run
|
||||
|
||||
|
||||
@@ -170,7 +170,8 @@ class SourceManager:
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# fetch results
|
||||
source_names = [source.name for source in source_data_list]
|
||||
|
||||
@@ -206,7 +206,8 @@ class StepManager:
|
||||
new_step = StepModel(**step_data)
|
||||
await new_step.create_async(session, no_commit=True, no_refresh=True)
|
||||
pydantic_step = new_step.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_step
|
||||
|
||||
@enforce_types
|
||||
@@ -266,7 +267,8 @@ class StepManager:
|
||||
raise Exception("Unauthorized")
|
||||
|
||||
step.tid = transaction_id
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -318,7 +320,8 @@ class StepManager:
|
||||
raise Exception("Unauthorized")
|
||||
|
||||
step.stop_reason = stop_reason
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return step
|
||||
|
||||
@enforce_types
|
||||
@@ -364,7 +367,8 @@ class StepManager:
|
||||
if stop_reason:
|
||||
step.stop_reason = stop_reason.stop_reason
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
pydantic_step = step.to_pydantic()
|
||||
# Send webhook notification for step completion outside the DB session
|
||||
webhook_service = WebhookService()
|
||||
@@ -415,7 +419,8 @@ class StepManager:
|
||||
if usage.completion_tokens_details:
|
||||
step.completion_tokens_details = usage.completion_tokens_details.model_dump()
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
pydantic_step = step.to_pydantic()
|
||||
# Send webhook notification for step completion outside the DB session
|
||||
webhook_service = WebhookService()
|
||||
@@ -455,7 +460,8 @@ class StepManager:
|
||||
if stop_reason:
|
||||
step.stop_reason = stop_reason.stop_reason
|
||||
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
pydantic_step = step.to_pydantic()
|
||||
# Send webhook notification for step completion outside the DB session
|
||||
webhook_service = WebhookService()
|
||||
|
||||
@@ -36,7 +36,8 @@ class TelemetryManager:
|
||||
provider_trace.response_json = json_loads(response_json_str)
|
||||
await provider_trace.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_provider_trace = provider_trace.to_pydantic()
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
return pydantic_provider_trace
|
||||
|
||||
|
||||
|
||||
@@ -1226,7 +1226,8 @@ class ToolManager:
|
||||
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
|
||||
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# fetch results (includes both inserted and skipped tools)
|
||||
tool_names = [tool.name for tool in tool_data_list]
|
||||
|
||||
@@ -120,7 +120,8 @@ async def clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
await session.execute(delete(SandboxEnvironmentVariable))
|
||||
await session.execute(delete(SandboxConfig))
|
||||
await session.commit() # Commit the deletion
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -43,7 +43,8 @@ async def _clear_tables():
|
||||
if table.name == "block_history":
|
||||
continue
|
||||
await session.execute(table.delete()) # Truncate table
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
def _run_server():
|
||||
|
||||
@@ -85,7 +85,8 @@ async def clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
await session.execute(delete(SandboxEnvironmentVariable))
|
||||
await session.execute(delete(SandboxConfig))
|
||||
await session.commit() # Commit the deletion
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -53,7 +53,8 @@ async def _clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
await session.execute(table.delete())
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@@ -60,7 +60,8 @@ async def clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
await session.execute(delete(SandboxEnvironmentVariable))
|
||||
await session.execute(delete(SandboxConfig))
|
||||
await session.commit() # Commit the deletion
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -2423,7 +2423,8 @@ async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user,
|
||||
)
|
||||
|
||||
session.add(corrupted_tool)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
corrupted_tool_id = corrupted_tool.id
|
||||
|
||||
# Now try to list tools - it should still work and not include the corrupted tool
|
||||
|
||||
@@ -79,7 +79,8 @@ def _clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
await session.execute(table.delete()) # Truncate table
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
asyncio.run(_clear())
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ async def _clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
await session.execute(table.delete()) # Truncate table
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@@ -218,7 +218,8 @@ async def clear_tables():
|
||||
async with db_registry.async_session() as session:
|
||||
await session.execute(delete(SandboxEnvironmentVariable))
|
||||
await session.execute(delete(SandboxConfig))
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -167,7 +167,8 @@ class TestMCPServerEncryption:
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_server)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Retrieve server directly by ID to avoid issues with other servers in DB
|
||||
test_server = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)
|
||||
@@ -183,7 +184,8 @@ class TestMCPServerEncryption:
|
||||
result = await session.execute(select(ORMMCPServer).where(ORMMCPServer.id == server_id))
|
||||
db_server = result.scalar_one()
|
||||
await session.delete(db_server)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
finally:
|
||||
# Restore original encryption key
|
||||
@@ -338,7 +340,8 @@ class TestMCPOAuthEncryption:
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_oauth)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Retrieve through manager by ID
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
@@ -456,7 +459,8 @@ class TestMCPOAuthEncryption:
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_oauth)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Retrieve through manager
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
@@ -505,7 +509,8 @@ class TestMCPOAuthEncryption:
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(db_oauth)
|
||||
await session.commit()
|
||||
# context manager now handles commits
|
||||
# await session.commit()
|
||||
|
||||
# Retrieve through manager
|
||||
test_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user