feat(asyncify): migrate tools (#2427)

This commit is contained in:
cthomas
2025-05-25 22:17:01 -07:00
committed by GitHub
parent bfc247b2d1
commit 2298b17b21
5 changed files with 156 additions and 24 deletions

View File

@@ -270,7 +270,7 @@ def list_agent_tools(
@router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool")
def attach_tool(
async def attach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
@@ -279,12 +279,12 @@ def attach_tool(
"""
Attach a tool to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
@router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool")
def detach_tool(
async def detach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
@@ -293,8 +293,8 @@ def detach_tool(
"""
Detach a tool from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")

View File

@@ -42,7 +42,7 @@ def delete_tool(
@router.get("/count", response_model=int, operation_id="count_tools")
def count_tools(
async def count_tools(
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"),
@@ -51,9 +51,8 @@ def count_tools(
Get a count of all tools available to agents belonging to the org of the user.
"""
try:
return server.tool_manager.size(
actor=server.user_manager.get_user_or_default(user_id=actor_id), include_base_tools=include_base_tools
)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools)
except Exception as e:
print(f"Error occurred: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -100,7 +99,7 @@ async def list_tools(
@router.post("/", response_model=Tool, operation_id="create_tool")
def create_tool(
async def create_tool(
request: ToolCreate = Body(...),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -109,9 +108,9 @@ def create_tool(
Create a new tool
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
tool = Tool(**request.model_dump())
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor)
except UniqueConstraintViolationError as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
@@ -132,7 +131,7 @@ def create_tool(
@router.put("/", response_model=Tool, operation_id="upsert_tool")
def upsert_tool(
async def upsert_tool(
request: ToolCreate = Body(...),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -141,8 +140,8 @@ def upsert_tool(
Create or update a tool
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump()), actor=actor)
return tool
except UniqueConstraintViolationError as e:
# Log the error and raise a conflict exception
@@ -266,7 +265,7 @@ def list_composio_actions_by_app(
@router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool")
def add_composio_tool(
async def add_composio_tool(
composio_action_name: str,
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -274,11 +273,11 @@ def add_composio_tool(
"""
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
tool_create = ToolCreate.from_composio(action_name=composio_action_name)
return server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=actor)
return await server.tool_manager.create_or_update_composio_tool_async(tool_create=tool_create, actor=actor)
except ConnectedAccountNotFoundError as e:
raise HTTPException(
status_code=400, # Bad Request

View File

@@ -73,6 +73,7 @@ from letta.services.helpers.agent_manager_helper import (
_apply_pagination_async,
_apply_tag_filter,
_process_relationship,
_process_relationship_async,
check_supports_structured_output,
compile_system_message,
derive_system_message,
@@ -2385,6 +2386,42 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a tool to an agent.
Args:
agent_id: ID of the agent to attach the tool to.
tool_id: ID of the tool to attach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
async with db_registry.async_session() as session:
# Verify the agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Use the _process_relationship helper to attach the tool
await _process_relationship_async(
session=session,
agent=agent,
relationship_name="tools",
model_class=ToolModel,
item_ids=[tool_id],
allow_partial=False, # Ensure the tool exists
replace=False, # Extend the existing tools
)
# Commit and refresh the agent
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
@trace_method
@enforce_types
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
@@ -2419,6 +2456,40 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@trace_method
@enforce_types
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a tool from an agent.
Args:
agent_id: ID of the agent to detach the tool from.
tool_id: ID of the tool to detach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
async with db_registry.async_session() as session:
# Verify the agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Filter out the tool to be detached
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
# Update the tools relationship
agent.tools = remaining_tools
# Commit and refresh the agent
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
@trace_method
@enforce_types
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:

View File

@@ -68,6 +68,50 @@ def _process_relationship(
current_relationship.extend(new_items)
@trace_method
async def _process_relationship_async(
session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
):
"""
Generalized function to handle relationships like tools, sources, and blocks using item IDs.
Args:
session: The database session.
agent: The AgentModel instance.
relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
model_class: The ORM class corresponding to the related items.
item_ids: List of IDs to set or update.
allow_partial: If True, allows missing items without raising errors.
replace: If True, replaces the entire relationship; otherwise, extends it.
Raises:
ValueError: If `allow_partial` is False and some IDs are missing.
"""
current_relationship = getattr(agent, relationship_name, [])
if not item_ids:
if replace:
setattr(agent, relationship_name, [])
return
# Retrieve models for the provided IDs
result = await session.execute(select(model_class).where(model_class.id.in_(item_ids)))
found_items = result.scalars().all()
# Validate all items are found if allow_partial is False
if not allow_partial and len(found_items) != len(item_ids):
missing = set(item_ids) - {item.id for item in found_items}
raise NoResultFound(f"Items not found in {relationship_name}: {missing}")
if replace:
# Replace the relationship
setattr(agent, relationship_name, found_items)
else:
# Extend the relationship (only add new items)
current_ids = {item.id for item in current_relationship}
new_items = [item for item in found_items if item.id not in current_ids]
current_relationship.extend(new_items)
def _process_tags(agent: AgentModel, tags: List[str], replace=True):
"""
Handles tags for an agent.

View File

@@ -125,6 +125,13 @@ class ToolManager:
PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
)
@enforce_types
@trace_method
async def create_or_update_composio_tool_async(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
return await self.create_or_update_tool_async(
PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
)
@enforce_types
@trace_method
def create_or_update_langchain_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
@@ -250,13 +257,13 @@ class ToolManager:
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
logger.warning("Deleted tool: ")
logger.warning(tool.pretty_print_columns())
self.delete_tool_by_id(tool.id, actor=actor)
await self.delete_tool_by_id_async(tool.id, actor=actor)
return results
@enforce_types
@trace_method
def size(
async def size_async(
self,
actor: PydanticUser,
include_base_tools: bool,
@@ -266,10 +273,10 @@ class ToolManager:
If include_builtin is True, it will also count the built-in tools.
"""
with db_registry.session() as session:
async with db_registry.async_session() as session:
if include_base_tools:
return ToolModel.size(db_session=session, actor=actor)
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
return await ToolModel.size_async(db_session=session, actor=actor)
return await ToolModel.size_async(db_session=session, actor=actor, name=LETTA_TOOL_SET)
@enforce_types
@trace_method
@@ -341,6 +348,17 @@ class ToolManager:
except NoResultFound:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
@trace_method
async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None:
"""Delete a tool by its ID."""
async with db_registry.async_session() as session:
try:
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
await tool.hard_delete_async(db_session=session, actor=actor)
except NoResultFound:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
@trace_method
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: