From 2298b17b216f8c43bde27881b3dca3d890224124 Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 25 May 2025 22:17:01 -0700 Subject: [PATCH] feat(asyncify): migrate tools (#2427) --- letta/server/rest_api/routers/v1/agents.py | 12 ++-- letta/server/rest_api/routers/v1/tools.py | 25 ++++--- letta/services/agent_manager.py | 71 +++++++++++++++++++ .../services/helpers/agent_manager_helper.py | 44 ++++++++++++ letta/services/tool_manager.py | 28 ++++++-- 5 files changed, 156 insertions(+), 24 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 5d61aade..4ddb509d 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -270,7 +270,7 @@ def list_agent_tools( @router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool") -def attach_tool( +async def attach_tool( agent_id: str, tool_id: str, server: "SyncServer" = Depends(get_letta_server), @@ -279,12 +279,12 @@ def attach_tool( """ Attach a tool to an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor) @router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool") -def detach_tool( +async def detach_tool( agent_id: str, tool_id: str, server: "SyncServer" = Depends(get_letta_server), @@ -293,8 +293,8 @@ def detach_tool( """ Detach a tool from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor) @router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 9dc6ffe4..591cc003 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -42,7 +42,7 @@ def delete_tool( @router.get("/count", response_model=int, operation_id="count_tools") -def count_tools( +async def count_tools( server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"), @@ -51,9 +51,8 @@ def count_tools( Get a count of all tools available to agents belonging to the org of the user. """ try: - return server.tool_manager.size( - actor=server.user_manager.get_user_or_default(user_id=actor_id), include_base_tools=include_base_tools - ) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools) except Exception as e: print(f"Error occurred: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -100,7 +99,7 @@ async def list_tools( @router.post("/", response_model=Tool, operation_id="create_tool") -def create_tool( +async def create_tool( request: ToolCreate = Body(...), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -109,9 +108,9 @@ def create_tool( Create a new tool """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) tool = Tool(**request.model_dump()) - return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor) + return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor) except UniqueConstraintViolationError as e: # Log or print the full exception here for debugging print(f"Error occurred: {e}") @@ -132,7 +131,7 @@ def create_tool( @router.put("/", response_model=Tool, operation_id="upsert_tool") -def upsert_tool( +async def upsert_tool( request: ToolCreate = Body(...), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -141,8 +140,8 @@ def upsert_tool( Create or update a tool """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump()), actor=actor) return tool except UniqueConstraintViolationError as e: # Log the error and raise a conflict exception @@ -266,7 +265,7 @@ def list_composio_actions_by_app( @router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool") -def add_composio_tool( +async def add_composio_tool( composio_action_name: str, server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -274,11 +273,11 @@ def add_composio_tool( """ Add a new Composio tool by action name (Composio refers to each tool as an `Action`) """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: tool_create = ToolCreate.from_composio(action_name=composio_action_name) - return server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=actor) + return await server.tool_manager.create_or_update_composio_tool_async(tool_create=tool_create, actor=actor) except ConnectedAccountNotFoundError as e: raise HTTPException( status_code=400, # Bad Request diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index aeb75403..31712640 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -73,6 +73,7 @@ from letta.services.helpers.agent_manager_helper import ( _apply_pagination_async, _apply_tag_filter, _process_relationship, + _process_relationship_async, check_supports_structured_output, compile_system_message, derive_system_message, @@ -2385,6 +2386,42 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method + @enforce_types + async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Attaches a tool to an agent. + + Args: + agent_id: ID of the agent to attach the tool to. + tool_id: ID of the tool to attach. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or tool is not found. + + Returns: + PydanticAgentState: The updated agent state. + """ + async with db_registry.async_session() as session: + # Verify the agent exists and user has permission to access it + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Use the _process_relationship helper to attach the tool + await _process_relationship_async( + session=session, + agent=agent, + relationship_name="tools", + model_class=ToolModel, + item_ids=[tool_id], + allow_partial=False, # Ensure the tool exists + replace=False, # Extend the existing tools + ) + + # Commit and refresh the agent + await agent.update_async(session, actor=actor) + return await agent.to_pydantic_async() + @trace_method @enforce_types def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: @@ -2419,6 +2456,40 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method + @enforce_types + async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: + """ + Detaches a tool from an agent. + + Args: + agent_id: ID of the agent to detach the tool from. + tool_id: ID of the tool to detach. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or tool is not found. + + Returns: + PydanticAgentState: The updated agent state. + """ + async with db_registry.async_session() as session: + # Verify the agent exists and user has permission to access it + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Filter out the tool to be detached + remaining_tools = [tool for tool in agent.tools if tool.id != tool_id] + + if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship + logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}") + + # Update the tools relationship + agent.tools = remaining_tools + + # Commit and refresh the agent + await agent.update_async(session, actor=actor) + return await agent.to_pydantic_async() + @trace_method @enforce_types def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]: diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index ef2bc7f7..fd4058ba 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -68,6 +68,50 @@ def _process_relationship( current_relationship.extend(new_items) +@trace_method +async def _process_relationship_async( + session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True +): + """ + Generalized function to handle relationships like tools, sources, and blocks using item IDs. + + Args: + session: The database session. + agent: The AgentModel instance. + relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources'). + model_class: The ORM class corresponding to the related items. + item_ids: List of IDs to set or update. + allow_partial: If True, allows missing items without raising errors. + replace: If True, replaces the entire relationship; otherwise, extends it. + + Raises: + ValueError: If `allow_partial` is False and some IDs are missing. + """ + current_relationship = getattr(agent, relationship_name, []) + if not item_ids: + if replace: + setattr(agent, relationship_name, []) + return + + # Retrieve models for the provided IDs + result = await session.execute(select(model_class).where(model_class.id.in_(item_ids))) + found_items = result.scalars().all() + + # Validate all items are found if allow_partial is False + if not allow_partial and len(found_items) != len(item_ids): + missing = set(item_ids) - {item.id for item in found_items} + raise NoResultFound(f"Items not found in {relationship_name}: {missing}") + + if replace: + # Replace the relationship + setattr(agent, relationship_name, found_items) + else: + # Extend the relationship (only add new items) + current_ids = {item.id for item in current_relationship} + new_items = [item for item in found_items if item.id not in current_ids] + current_relationship.extend(new_items) + + def _process_tags(agent: AgentModel, tags: List[str], replace=True): """ Handles tags for an agent. diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 36711fcc..ae2f7aed 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -125,6 +125,13 @@ class ToolManager: PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor ) + @enforce_types + @trace_method + async def create_or_update_composio_tool_async(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: + return await self.create_or_update_tool_async( + PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor + ) + @enforce_types @trace_method def create_or_update_langchain_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: @@ -250,13 +257,13 @@ class ToolManager: logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}") logger.warning("Deleted tool: ") logger.warning(tool.pretty_print_columns()) - self.delete_tool_by_id(tool.id, actor=actor) + await self.delete_tool_by_id_async(tool.id, actor=actor) return results @enforce_types @trace_method - def size( + async def size_async( self, actor: PydanticUser, include_base_tools: bool, @@ -266,10 +273,10 @@ class ToolManager: If include_builtin is True, it will also count the built-in tools. """ - with db_registry.session() as session: + async with db_registry.async_session() as session: if include_base_tools: - return ToolModel.size(db_session=session, actor=actor) - return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET) + return await ToolModel.size_async(db_session=session, actor=actor) + return await ToolModel.size_async(db_session=session, actor=actor, name=LETTA_TOOL_SET) @enforce_types @trace_method @@ -341,6 +348,17 @@ class ToolManager: except NoResultFound: raise ValueError(f"Tool with id {tool_id} not found.") + @enforce_types + @trace_method + async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None: + """Delete a tool by its ID.""" + async with db_registry.async_session() as session: + try: + tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) + await tool.hard_delete_async(db_session=session, actor=actor) + except NoResultFound: + raise ValueError(f"Tool with id {tool_id} not found.") + @enforce_types @trace_method def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: