feat(asyncify): migrate tools (#2427)
This commit is contained in:
@@ -270,7 +270,7 @@ def list_agent_tools(
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool")
|
||||
def attach_tool(
|
||||
async def attach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -279,12 +279,12 @@ def attach_tool(
|
||||
"""
|
||||
Attach a tool to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool")
|
||||
def detach_tool(
|
||||
async def detach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -293,8 +293,8 @@ def detach_tool(
|
||||
"""
|
||||
Detach a tool from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
|
||||
|
||||
@@ -42,7 +42,7 @@ def delete_tool(
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_tools")
|
||||
def count_tools(
|
||||
async def count_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"),
|
||||
@@ -51,9 +51,8 @@ def count_tools(
|
||||
Get a count of all tools available to agents belonging to the org of the user.
|
||||
"""
|
||||
try:
|
||||
return server.tool_manager.size(
|
||||
actor=server.user_manager.get_user_or_default(user_id=actor_id), include_base_tools=include_base_tools
|
||||
)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -100,7 +99,7 @@ async def list_tools(
|
||||
|
||||
|
||||
@router.post("/", response_model=Tool, operation_id="create_tool")
|
||||
def create_tool(
|
||||
async def create_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@@ -109,9 +108,9 @@ def create_tool(
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
|
||||
return await server.tool_manager.create_tool_async(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log or print the full exception here for debugging
|
||||
print(f"Error occurred: {e}")
|
||||
@@ -132,7 +131,7 @@ def create_tool(
|
||||
|
||||
|
||||
@router.put("/", response_model=Tool, operation_id="upsert_tool")
|
||||
def upsert_tool(
|
||||
async def upsert_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -141,8 +140,8 @@ def upsert_tool(
|
||||
Create or update a tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
tool = await server.tool_manager.create_or_update_tool_async(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
# Log the error and raise a conflict exception
|
||||
@@ -266,7 +265,7 @@ def list_composio_actions_by_app(
|
||||
|
||||
|
||||
@router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool")
|
||||
def add_composio_tool(
|
||||
async def add_composio_tool(
|
||||
composio_action_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -274,11 +273,11 @@ def add_composio_tool(
|
||||
"""
|
||||
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
tool_create = ToolCreate.from_composio(action_name=composio_action_name)
|
||||
return server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=actor)
|
||||
return await server.tool_manager.create_or_update_composio_tool_async(tool_create=tool_create, actor=actor)
|
||||
except ConnectedAccountNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
|
||||
@@ -73,6 +73,7 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
_apply_pagination_async,
|
||||
_apply_tag_filter,
|
||||
_process_relationship,
|
||||
_process_relationship_async,
|
||||
check_supports_structured_output,
|
||||
compile_system_message,
|
||||
derive_system_message,
|
||||
@@ -2385,6 +2386,42 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the tool to.
|
||||
tool_id: ID of the tool to attach.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or tool is not found.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Use the _process_relationship helper to attach the tool
|
||||
await _process_relationship_async(
|
||||
session=session,
|
||||
agent=agent,
|
||||
relationship_name="tools",
|
||||
model_class=ToolModel,
|
||||
item_ids=[tool_id],
|
||||
allow_partial=False, # Ensure the tool exists
|
||||
replace=False, # Extend the existing tools
|
||||
)
|
||||
|
||||
# Commit and refresh the agent
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
@@ -2419,6 +2456,40 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the tool from.
|
||||
tool_id: ID of the tool to detach.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or tool is not found.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Filter out the tool to be detached
|
||||
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
|
||||
|
||||
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
|
||||
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
|
||||
|
||||
# Update the tools relationship
|
||||
agent.tools = remaining_tools
|
||||
|
||||
# Commit and refresh the agent
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
|
||||
|
||||
@@ -68,6 +68,50 @@ def _process_relationship(
|
||||
current_relationship.extend(new_items)
|
||||
|
||||
|
||||
@trace_method
|
||||
async def _process_relationship_async(
|
||||
session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
|
||||
):
|
||||
"""
|
||||
Generalized function to handle relationships like tools, sources, and blocks using item IDs.
|
||||
|
||||
Args:
|
||||
session: The database session.
|
||||
agent: The AgentModel instance.
|
||||
relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
|
||||
model_class: The ORM class corresponding to the related items.
|
||||
item_ids: List of IDs to set or update.
|
||||
allow_partial: If True, allows missing items without raising errors.
|
||||
replace: If True, replaces the entire relationship; otherwise, extends it.
|
||||
|
||||
Raises:
|
||||
ValueError: If `allow_partial` is False and some IDs are missing.
|
||||
"""
|
||||
current_relationship = getattr(agent, relationship_name, [])
|
||||
if not item_ids:
|
||||
if replace:
|
||||
setattr(agent, relationship_name, [])
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
result = await session.execute(select(model_class).where(model_class.id.in_(item_ids)))
|
||||
found_items = result.scalars().all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(item_ids):
|
||||
missing = set(item_ids) - {item.id for item in found_items}
|
||||
raise NoResultFound(f"Items not found in {relationship_name}: {missing}")
|
||||
|
||||
if replace:
|
||||
# Replace the relationship
|
||||
setattr(agent, relationship_name, found_items)
|
||||
else:
|
||||
# Extend the relationship (only add new items)
|
||||
current_ids = {item.id for item in current_relationship}
|
||||
new_items = [item for item in found_items if item.id not in current_ids]
|
||||
current_relationship.extend(new_items)
|
||||
|
||||
|
||||
def _process_tags(agent: AgentModel, tags: List[str], replace=True):
|
||||
"""
|
||||
Handles tags for an agent.
|
||||
|
||||
@@ -125,6 +125,13 @@ class ToolManager:
|
||||
PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_or_update_composio_tool_async(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
return await self.create_or_update_tool_async(
|
||||
PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_or_update_langchain_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
@@ -250,13 +257,13 @@ class ToolManager:
|
||||
logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
|
||||
logger.warning("Deleted tool: ")
|
||||
logger.warning(tool.pretty_print_columns())
|
||||
self.delete_tool_by_id(tool.id, actor=actor)
|
||||
await self.delete_tool_by_id_async(tool.id, actor=actor)
|
||||
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def size(
|
||||
async def size_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
include_base_tools: bool,
|
||||
@@ -266,10 +273,10 @@ class ToolManager:
|
||||
|
||||
If include_builtin is True, it will also count the built-in tools.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
if include_base_tools:
|
||||
return ToolModel.size(db_session=session, actor=actor)
|
||||
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
|
||||
return await ToolModel.size_async(db_session=session, actor=actor)
|
||||
return await ToolModel.size_async(db_session=session, actor=actor, name=LETTA_TOOL_SET)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -341,6 +348,17 @@ class ToolManager:
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a tool by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
await tool.hard_delete_async(db_session=session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
|
||||
|
||||
Reference in New Issue
Block a user