fix: move db logic inside sessions (#2553)
This commit is contained in:
@@ -538,11 +538,10 @@ class AgentManager:
|
||||
new_agent.message_ids = [msg.id for msg in init_messages]
|
||||
|
||||
await session.refresh(new_agent)
|
||||
result = await new_agent.to_pydantic_async()
|
||||
|
||||
# Using the synchronous version since we don't have an async version yet
|
||||
# If you implement an async version of create_many_messages, you can switch to that
|
||||
await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor)
|
||||
return await new_agent.to_pydantic_async()
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
def _generate_initial_message_sequence(
|
||||
@@ -1700,14 +1699,14 @@ class AgentManager:
|
||||
|
||||
# Force rebuild of system prompt so that the agent is updated with passage count
|
||||
# and recent passages and add system message alert to agent
|
||||
await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
await self.append_system_message_async(
|
||||
agent_id=agent_id,
|
||||
content=DATA_SOURCE_ATTACH_ALERT,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return await agent.to_pydantic_async()
|
||||
return pydantic_agent
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
@@ -2622,7 +2621,7 @@ class AgentManager:
|
||||
result = await session.execute(query)
|
||||
# Extract the tag values from the result
|
||||
results = [row[0] for row in result.all()]
|
||||
return results
|
||||
return results
|
||||
|
||||
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
|
||||
if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION":
|
||||
|
||||
@@ -108,7 +108,7 @@ class MCPManager:
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
|
||||
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
||||
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
||||
|
||||
@enforce_types
|
||||
async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
||||
@@ -200,7 +200,7 @@ class MCPManager:
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
return mcp_server.to_pydantic()
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
# @enforce_types
|
||||
# async def delete_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -58,7 +58,7 @@ class MessageManager:
|
||||
"""Fetch messages by ID and return them in the requested order. Async version of above function."""
|
||||
async with db_registry.async_session() as session:
|
||||
results = await MessageModel.read_multiple_async(db_session=session, identifiers=message_ids, actor=actor)
|
||||
return self._get_messages_by_id_postprocess(results, message_ids)
|
||||
return self._get_messages_by_id_postprocess(results, message_ids)
|
||||
|
||||
def _get_messages_by_id_postprocess(
|
||||
self,
|
||||
|
||||
@@ -340,7 +340,7 @@ class SandboxConfigManager:
|
||||
async with db_registry.async_session() as session:
|
||||
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
|
||||
await env_var.create_async(session, actor=actor)
|
||||
return env_var.to_pydantic()
|
||||
return env_var.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -31,7 +31,7 @@ class SourceManager:
|
||||
source.organization_id = actor.organization_id
|
||||
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
|
||||
await source.create_async(session, actor=actor)
|
||||
return source.to_pydantic()
|
||||
return source.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -152,7 +152,7 @@ class SourceManager:
|
||||
file_metadata.organization_id = actor.organization_id
|
||||
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
|
||||
await file_metadata.create_async(session, actor=actor)
|
||||
return file_metadata.to_pydantic()
|
||||
return file_metadata.to_pydantic()
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
|
||||
@@ -169,7 +169,7 @@ class ToolManager:
|
||||
|
||||
tool = ToolModel(**tool_data)
|
||||
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
|
||||
return tool.to_pydantic()
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -239,6 +239,7 @@ class ToolManager:
|
||||
@trace_method
|
||||
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
tools_to_delete = []
|
||||
async with db_registry.async_session() as session:
|
||||
tools = await ToolModel.list_async(
|
||||
db_session=session,
|
||||
@@ -247,17 +248,20 @@ class ToolManager:
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
|
||||
# Remove any malformed tools
|
||||
results = []
|
||||
for tool in tools:
|
||||
try:
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
results.append(pydantic_tool)
|
||||
except (ValueError, ModuleNotFoundError, AttributeError) as e:
|
||||
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
|
||||
logger.warning("Deleted tool: ")
|
||||
logger.warning(tool.pretty_print_columns())
|
||||
await self.delete_tool_by_id_async(tool.id, actor=actor)
|
||||
# Remove any malformed tools
|
||||
results = []
|
||||
for tool in tools:
|
||||
try:
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
results.append(pydantic_tool)
|
||||
except (ValueError, ModuleNotFoundError, AttributeError) as e:
|
||||
tools_to_delete.append(tool)
|
||||
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
|
||||
logger.warning("Deleted tool: ")
|
||||
logger.warning(tool.pretty_print_columns())
|
||||
|
||||
for tool in tools_to_delete:
|
||||
await self.delete_tool_by_id_async(tool.id, actor=actor)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user