fix: move db logic inside sessions (#2553)

This commit is contained in:
cthomas
2025-05-30 14:02:25 -07:00
committed by GitHub
parent 6928dc1705
commit 61ea680bdb
6 changed files with 27 additions and 24 deletions

View File

@@ -538,11 +538,10 @@ class AgentManager:
new_agent.message_ids = [msg.id for msg in init_messages]
await session.refresh(new_agent)
result = await new_agent.to_pydantic_async()
# Using the synchronous version since we don't have an async version yet
# If you implement an async version of create_many_messages, you can switch to that
await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor)
return await new_agent.to_pydantic_async()
return result
@enforce_types
def _generate_initial_message_sequence(
@@ -1700,14 +1699,14 @@ class AgentManager:
# Force rebuild of system prompt so that the agent is updated with passage count
# and recent passages and add system message alert to agent
await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
await self.append_system_message_async(
agent_id=agent_id,
content=DATA_SOURCE_ATTACH_ALERT,
actor=actor,
)
return await agent.to_pydantic_async()
return pydantic_agent
@trace_method
@enforce_types
@@ -2622,7 +2621,7 @@ class AgentManager:
result = await session.execute(query)
# Extract the tag values from the result
results = [row[0] for row in result.all()]
return results
return results
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION":

View File

@@ -108,7 +108,7 @@ class MCPManager:
organization_id=actor.organization_id,
)
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
@enforce_types
async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
@@ -200,7 +200,7 @@ class MCPManager:
"mcp_server_name": mcp_server_name,
},
)
return mcp_server.to_pydantic()
return mcp_server.to_pydantic()
# @enforce_types
# async def delete_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> None:

View File

@@ -58,7 +58,7 @@ class MessageManager:
"""Fetch messages by ID and return them in the requested order. Async version of above function."""
async with db_registry.async_session() as session:
results = await MessageModel.read_multiple_async(db_session=session, identifiers=message_ids, actor=actor)
return self._get_messages_by_id_postprocess(results, message_ids)
return self._get_messages_by_id_postprocess(results, message_ids)
def _get_messages_by_id_postprocess(
self,

View File

@@ -340,7 +340,7 @@ class SandboxConfigManager:
async with db_registry.async_session() as session:
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
await env_var.create_async(session, actor=actor)
return env_var.to_pydantic()
return env_var.to_pydantic()
@enforce_types
@trace_method

View File

@@ -31,7 +31,7 @@ class SourceManager:
source.organization_id = actor.organization_id
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
await source.create_async(session, actor=actor)
return source.to_pydantic()
return source.to_pydantic()
@enforce_types
@trace_method
@@ -152,7 +152,7 @@ class SourceManager:
file_metadata.organization_id = actor.organization_id
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
await file_metadata.create_async(session, actor=actor)
return file_metadata.to_pydantic()
return file_metadata.to_pydantic()
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types

View File

@@ -169,7 +169,7 @@ class ToolManager:
tool = ToolModel(**tool_data)
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
return tool.to_pydantic()
return tool.to_pydantic()
@enforce_types
@trace_method
@@ -239,6 +239,7 @@ class ToolManager:
@trace_method
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination."""
tools_to_delete = []
async with db_registry.async_session() as session:
tools = await ToolModel.list_async(
db_session=session,
@@ -247,17 +248,20 @@ class ToolManager:
organization_id=actor.organization_id,
)
# Remove any malformed tools
results = []
for tool in tools:
try:
pydantic_tool = tool.to_pydantic()
results.append(pydantic_tool)
except (ValueError, ModuleNotFoundError, AttributeError) as e:
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
logger.warning("Deleted tool: ")
logger.warning(tool.pretty_print_columns())
await self.delete_tool_by_id_async(tool.id, actor=actor)
# Remove any malformed tools
results = []
for tool in tools:
try:
pydantic_tool = tool.to_pydantic()
results.append(pydantic_tool)
except (ValueError, ModuleNotFoundError, AttributeError) as e:
tools_to_delete.append(tool)
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
logger.warning("Deleted tool: ")
logger.warning(tool.pretty_print_columns())
for tool in tools_to_delete:
await self.delete_tool_by_id_async(tool.id, actor=actor)
return results