diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 344bd221..99349070 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -538,11 +538,10 @@ class AgentManager: new_agent.message_ids = [msg.id for msg in init_messages] await session.refresh(new_agent) + result = await new_agent.to_pydantic_async() - # Using the synchronous version since we don't have an async version yet - # If you implement an async version of create_many_messages, you can switch to that await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor) - return await new_agent.to_pydantic_async() + return result @enforce_types def _generate_initial_message_sequence( @@ -1700,14 +1699,14 @@ class AgentManager: # Force rebuild of system prompt so that the agent is updated with passage count # and recent passages and add system message alert to agent - await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) + pydantic_agent = await self.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) await self.append_system_message_async( agent_id=agent_id, content=DATA_SOURCE_ATTACH_ALERT, actor=actor, ) - return await agent.to_pydantic_async() + return pydantic_agent @trace_method @enforce_types @@ -2622,7 +2621,7 @@ class AgentManager: result = await session.execute(query) # Extract the tag values from the result results = [row[0] for row in result.all()] - return results + return results async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION": diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 7e93638b..a66f536f 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -108,7 +108,7 @@ class MCPManager: organization_id=actor.organization_id, ) - return [mcp_server.to_pydantic() for mcp_server in mcp_servers] + return [mcp_server.to_pydantic() for mcp_server in mcp_servers] @enforce_types async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: @@ -200,7 +200,7 @@ class MCPManager: "mcp_server_name": mcp_server_name, }, ) - return mcp_server.to_pydantic() + return mcp_server.to_pydantic() # @enforce_types # async def delete_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> None: diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 09c38ecf..01763689 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -58,7 +58,7 @@ class MessageManager: """Fetch messages by ID and return them in the requested order. Async version of above function.""" async with db_registry.async_session() as session: results = await MessageModel.read_multiple_async(db_session=session, identifiers=message_ids, actor=actor) - return self._get_messages_by_id_postprocess(results, message_ids) + return self._get_messages_by_id_postprocess(results, message_ids) def _get_messages_by_id_postprocess( self, diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index a8d17fb0..211f6975 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -340,7 +340,7 @@ class SandboxConfigManager: async with db_registry.async_session() as session: env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True)) await env_var.create_async(session, actor=actor) - return env_var.to_pydantic() + return env_var.to_pydantic() @enforce_types @trace_method diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 94ca69f5..7580f77e 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -31,7 +31,7 @@ class SourceManager: source.organization_id = actor.organization_id source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True)) await source.create_async(session, actor=actor) - return source.to_pydantic() + return source.to_pydantic() @enforce_types @trace_method @@ -152,7 +152,7 @@ class SourceManager: file_metadata.organization_id = actor.organization_id file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True)) await file_metadata.create_async(session, actor=actor) - return file_metadata.to_pydantic() + return file_metadata.to_pydantic() # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index ae2f7aed..78652d4d 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -169,7 +169,7 @@ class ToolManager: tool = ToolModel(**tool_data) await tool.create_async(session, actor=actor) # Re-raise other database-related errors - return tool.to_pydantic() + return tool.to_pydantic() @enforce_types @trace_method @@ -239,6 +239,7 @@ class ToolManager: @trace_method async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" + tools_to_delete = [] async with db_registry.async_session() as session: tools = await ToolModel.list_async( db_session=session, @@ -247,17 +248,20 @@ class ToolManager: organization_id=actor.organization_id, ) - # Remove any malformed tools - results = [] - for tool in tools: - try: - pydantic_tool = tool.to_pydantic() - results.append(pydantic_tool) - except (ValueError, ModuleNotFoundError, AttributeError) as e: - logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}") - logger.warning("Deleted tool: ") - logger.warning(tool.pretty_print_columns()) - await self.delete_tool_by_id_async(tool.id, actor=actor) + # Remove any malformed tools + results = [] + for tool in tools: + try: + pydantic_tool = tool.to_pydantic() + results.append(pydantic_tool) + except (ValueError, ModuleNotFoundError, AttributeError) as e: + tools_to_delete.append(tool) + logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}") + logger.warning("Deleted tool: ") + logger.warning(tool.pretty_print_columns()) + + for tool in tools_to_delete: + await self.delete_tool_by_id_async(tool.id, actor=actor) return results