feat: orm ToolsAgents migration (#2173)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -551,6 +551,7 @@ tags
|
|||||||
[._]*.un~
|
[._]*.un~
|
||||||
|
|
||||||
### VisualStudioCode ###
|
### VisualStudioCode ###
|
||||||
|
.vscode/
|
||||||
.vscode/*
|
.vscode/*
|
||||||
!.vscode/settings.json
|
!.vscode/settings.json
|
||||||
!.vscode/tasks.json
|
!.vscode/tasks.json
|
||||||
|
|||||||
44
alembic/versions/08b2f8225812_adding_toolsagents_orm.py
Normal file
44
alembic/versions/08b2f8225812_adding_toolsagents_orm.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""adding ToolsAgents ORM
|
||||||
|
|
||||||
|
Revision ID: 08b2f8225812
|
||||||
|
Revises: 3c683a662c82
|
||||||
|
Create Date: 2024-12-05 16:46:51.258831
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '08b2f8225812'
|
||||||
|
down_revision: Union[str, None] = '3c683a662c82'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('tools_agents',
|
||||||
|
sa.Column('agent_id', sa.String(), nullable=False),
|
||||||
|
sa.Column('tool_id', sa.String(), nullable=False),
|
||||||
|
sa.Column('tool_name', sa.String(), nullable=False),
|
||||||
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||||
|
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||||
|
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'),
|
||||||
|
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'),
|
||||||
|
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('tools_agents')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -56,7 +56,7 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
|
|||||||
pause_heartbeats.__doc__ = pause_heartbeats_docstring
|
pause_heartbeats.__doc__ = pause_heartbeats_docstring
|
||||||
|
|
||||||
|
|
||||||
def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
|
def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Search prior conversation history using case-insensitive string matching.
|
Search prior conversation history using case-insensitive string matching.
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt
|
|||||||
return results_str
|
return results_str
|
||||||
|
|
||||||
|
|
||||||
def conversation_search_date(self: Agent, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
|
def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Search prior conversation history using a date range.
|
Search prior conversation history using a date range.
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page:
|
|||||||
return results_str
|
return results_str
|
||||||
|
|
||||||
|
|
||||||
def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
|
def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
|
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
|
def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Search archival memory using semantic (embedding-based) search.
|
Search archival memory using semantic (embedding-based) search.
|
||||||
|
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ from letta.orm.organization import Organization
|
|||||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||||
from letta.orm.source import Source
|
from letta.orm.source import Source
|
||||||
from letta.orm.tool import Tool
|
from letta.orm.tool import Tool
|
||||||
|
from letta.orm.tools_agents import ToolsAgents
|
||||||
from letta.orm.user import User
|
from letta.orm.user import User
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import JSON, String, UniqueConstraint
|
from sqlalchemy import JSON, String, UniqueConstraint, event
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
# TODO everything in functions should live in this model
|
# TODO everything in functions should live in this model
|
||||||
@@ -11,6 +11,7 @@ from letta.schemas.tool import Tool as PydanticTool
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
|
from letta.orm.tools_agents import ToolsAgents
|
||||||
|
|
||||||
|
|
||||||
class Tool(SqlalchemyBase, OrganizationMixin):
|
class Tool(SqlalchemyBase, OrganizationMixin):
|
||||||
@@ -40,3 +41,23 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
|||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||||
|
tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
|
||||||
|
# Add event listener to update tool_name in ToolsAgents when Tool name changes
|
||||||
|
@event.listens_for(Tool, 'before_update')
|
||||||
|
def update_tool_name_in_tools_agents(mapper, connection, target):
|
||||||
|
"""Update tool_name in ToolsAgents when Tool name changes."""
|
||||||
|
state = target._sa_instance_state
|
||||||
|
history = state.get_history('name', passive=True)
|
||||||
|
if not history.has_changes():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the new name and update all associated ToolsAgents records
|
||||||
|
new_name = target.name
|
||||||
|
from letta.orm.tools_agents import ToolsAgents
|
||||||
|
connection.execute(
|
||||||
|
ToolsAgents.__table__.update().where(
|
||||||
|
ToolsAgents.tool_id == target.id
|
||||||
|
).values(tool_name=new_name)
|
||||||
|
)
|
||||||
|
|||||||
32
letta/orm/tools_agents.py
Normal file
32
letta/orm/tools_agents.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
|
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsAgents(SqlalchemyBase):
|
||||||
|
"""Agents can have one or many tools associated with them."""
|
||||||
|
|
||||||
|
__tablename__ = "tools_agents"
|
||||||
|
__pydantic_model__ = PydanticToolsAgents
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"agent_id",
|
||||||
|
"tool_name",
|
||||||
|
name="unique_tool_per_agent",
|
||||||
|
),
|
||||||
|
ForeignKeyConstraint(
|
||||||
|
["tool_id"],
|
||||||
|
["tools.id"],
|
||||||
|
name="fk_tool_id",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each agent must have unique tool names
|
||||||
|
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||||
|
tool_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||||
|
tool_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents")
|
||||||
32
letta/schemas/tools_agents.py
Normal file
32
letta/schemas/tools_agents.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from letta.schemas.letta_base import LettaBase
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsAgentsBase(LettaBase):
|
||||||
|
__id_prefix__ = "tools_agents"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolsAgents(ToolsAgentsBase):
|
||||||
|
"""
|
||||||
|
Schema representing the relationship between tools and agents.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
agent_id (str): The ID of the associated agent.
|
||||||
|
tool_id (str): The ID of the associated tool.
|
||||||
|
tool_name (str): The name of the tool.
|
||||||
|
created_at (datetime): The date this relationship was created.
|
||||||
|
updated_at (datetime): The date this relationship was last updated.
|
||||||
|
is_deleted (bool): Whether this tool-agent relationship is deleted or not.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str = ToolsAgentsBase.generate_id_field()
|
||||||
|
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||||
|
tool_id: str = Field(..., description="The ID of the associated tool.")
|
||||||
|
tool_name: str = Field(..., description="The name of the tool.")
|
||||||
|
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||||
|
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
|
||||||
|
is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.")
|
||||||
@@ -77,6 +77,7 @@ from letta.schemas.user import User
|
|||||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||||
from letta.services.block_manager import BlockManager
|
from letta.services.block_manager import BlockManager
|
||||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||||
|
from letta.services.tools_agents_manager import ToolsAgentsManager
|
||||||
from letta.services.job_manager import JobManager
|
from letta.services.job_manager import JobManager
|
||||||
from letta.services.organization_manager import OrganizationManager
|
from letta.services.organization_manager import OrganizationManager
|
||||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||||
@@ -259,6 +260,7 @@ class SyncServer(Server):
|
|||||||
self.agents_tags_manager = AgentsTagsManager()
|
self.agents_tags_manager = AgentsTagsManager()
|
||||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||||
self.blocks_agents_manager = BlocksAgentsManager()
|
self.blocks_agents_manager = BlocksAgentsManager()
|
||||||
|
self.tools_agents_manager = ToolsAgentsManager()
|
||||||
self.job_manager = JobManager()
|
self.job_manager = JobManager()
|
||||||
|
|
||||||
# Managers that interface with parallelism
|
# Managers that interface with parallelism
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class ToolManager:
|
|||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
try:
|
try:
|
||||||
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
||||||
tool.delete(db_session=session, actor=actor)
|
tool.hard_delete(db_session=session, actor=actor)
|
||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||||
|
|
||||||
|
|||||||
94
letta/services/tools_agents_manager.py
Normal file
94
letta/services/tools_agents_manager.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from letta.orm.errors import NoResultFound
|
||||||
|
from letta.orm.organization import Organization
|
||||||
|
from letta.orm.tool import Tool
|
||||||
|
from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel
|
||||||
|
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
||||||
|
|
||||||
|
class ToolsAgentsManager:
|
||||||
|
"""Manages the relationship between tools and agents."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from letta.server.server import db_context
|
||||||
|
self.session_maker = db_context
|
||||||
|
|
||||||
|
def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents:
|
||||||
|
"""Add a tool to an agent.
|
||||||
|
|
||||||
|
When a tool is added to an agent, it will be added to all agents in the same organization.
|
||||||
|
"""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
# Check if the tool-agent combination already exists for this agent
|
||||||
|
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||||
|
warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.")
|
||||||
|
except NoResultFound:
|
||||||
|
tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name)
|
||||||
|
tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True))
|
||||||
|
tools_agents_record.create(session)
|
||||||
|
|
||||||
|
return tools_agents_record.to_pydantic()
|
||||||
|
|
||||||
|
def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None:
|
||||||
|
"""Remove a tool from an agent by its name.
|
||||||
|
|
||||||
|
When a tool is removed from an agent, it will be removed from all agents in the same organization.
|
||||||
|
"""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
# Find and delete the tool-agent association for the agent
|
||||||
|
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||||
|
tools_agents_record.hard_delete(session)
|
||||||
|
return tools_agents_record.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
|
||||||
|
|
||||||
|
def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents:
|
||||||
|
"""Remove a tool with an ID from an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id)
|
||||||
|
tools_agents_record.hard_delete(session)
|
||||||
|
return tools_agents_record.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.")
|
||||||
|
|
||||||
|
def list_tool_ids_for_agent(self, agent_id: str) -> List[str]:
|
||||||
|
"""List all tool IDs associated with a specific agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||||
|
return [record.tool_id for record in tools_agents_record]
|
||||||
|
|
||||||
|
def list_tool_names_for_agent(self, agent_id: str) -> List[str]:
|
||||||
|
"""List all tool names associated with a specific agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||||
|
return [record.tool_name for record in tools_agents_record]
|
||||||
|
|
||||||
|
def list_agent_ids_with_tool(self, tool_id: str) -> List[str]:
|
||||||
|
"""List all agents associated with a specific tool."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id)
|
||||||
|
return [record.agent_id for record in tools_agents_record]
|
||||||
|
|
||||||
|
def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str:
|
||||||
|
"""Get the tool ID for a specific tool name for an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||||
|
return tools_agents_record.tool_id
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
|
||||||
|
|
||||||
|
def remove_all_agent_tools(self, agent_id: str) -> None:
|
||||||
|
"""Remove all tools associated with an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||||
|
for record in tools_agents_records:
|
||||||
|
record.hard_delete(session)
|
||||||
@@ -16,6 +16,7 @@ from letta.orm import (
|
|||||||
SandboxEnvironmentVariable,
|
SandboxEnvironmentVariable,
|
||||||
Source,
|
Source,
|
||||||
Tool,
|
Tool,
|
||||||
|
ToolsAgents,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from letta.orm.agents_tags import AgentsTags
|
from letta.orm.agents_tags import AgentsTags
|
||||||
@@ -46,9 +47,10 @@ from letta.schemas.sandbox_config import (
|
|||||||
from letta.schemas.source import Source as PydanticSource
|
from letta.schemas.source import Source as PydanticSource
|
||||||
from letta.schemas.source import SourceUpdate
|
from letta.schemas.source import SourceUpdate
|
||||||
from letta.schemas.tool import Tool as PydanticTool
|
from letta.schemas.tool import Tool as PydanticTool
|
||||||
from letta.schemas.tool import ToolUpdate
|
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||||
from letta.services.block_manager import BlockManager
|
from letta.services.block_manager import BlockManager
|
||||||
from letta.services.organization_manager import OrganizationManager
|
from letta.services.organization_manager import OrganizationManager
|
||||||
|
from letta.services.tool_manager import ToolManager
|
||||||
from letta.settings import tool_settings
|
from letta.settings import tool_settings
|
||||||
|
|
||||||
utils.DEBUG = True
|
utils.DEBUG = True
|
||||||
@@ -76,6 +78,7 @@ def clear_tables(server: SyncServer):
|
|||||||
"""Fixture to clear the organization table before each test."""
|
"""Fixture to clear the organization table before each test."""
|
||||||
with server.organization_manager.session_maker() as session:
|
with server.organization_manager.session_maker() as session:
|
||||||
session.execute(delete(Job))
|
session.execute(delete(Job))
|
||||||
|
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||||
session.execute(delete(BlocksAgents))
|
session.execute(delete(BlocksAgents))
|
||||||
session.execute(delete(AgentsTags))
|
session.execute(delete(AgentsTags))
|
||||||
session.execute(delete(SandboxEnvironmentVariable))
|
session.execute(delete(SandboxEnvironmentVariable))
|
||||||
@@ -240,6 +243,37 @@ def other_block(server: SyncServer, default_user):
|
|||||||
yield block
|
yield block
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def other_tool(server: SyncServer, default_user, default_organization):
|
||||||
|
def print_other_tool(message: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
message (str): The message to print.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The message that was printed.
|
||||||
|
"""
|
||||||
|
print(message)
|
||||||
|
return message
|
||||||
|
|
||||||
|
# Set up tool details
|
||||||
|
source_code = parse_source_code(print_other_tool)
|
||||||
|
source_type = "python"
|
||||||
|
description = "other_tool_description"
|
||||||
|
tags = ["test"]
|
||||||
|
|
||||||
|
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||||
|
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||||
|
|
||||||
|
derived_name = derived_json_schema["name"]
|
||||||
|
tool.json_schema = derived_json_schema
|
||||||
|
tool.name = derived_name
|
||||||
|
|
||||||
|
tool = server.tool_manager.create_tool(tool, actor=default_user)
|
||||||
|
|
||||||
|
# Yield the created tool
|
||||||
|
yield tool
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
config = LettaConfig.load()
|
config = LettaConfig.load()
|
||||||
@@ -1155,6 +1189,107 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
|
|||||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================================================================
|
||||||
|
# ToolsAgentsManager Tests
|
||||||
|
# ======================================================================================================================
|
||||||
|
def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool):
|
||||||
|
tool_association = server.tools_agents_manager.add_tool_to_agent(
|
||||||
|
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool_association.agent_id == sarah_agent.id
|
||||||
|
assert tool_association.tool_id == print_tool.id
|
||||||
|
assert tool_association.tool_name == print_tool.name
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool):
|
||||||
|
# Add the tool
|
||||||
|
tool_association = server.tools_agents_manager.add_tool_to_agent(
|
||||||
|
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
|
||||||
|
)
|
||||||
|
assert tool_association.tool_name == print_tool.name
|
||||||
|
|
||||||
|
# Change the tool name
|
||||||
|
new_name = "banana"
|
||||||
|
tool = server.tool_manager.update_tool_by_id(
|
||||||
|
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
|
||||||
|
)
|
||||||
|
assert tool.name == new_name
|
||||||
|
|
||||||
|
# Get the association
|
||||||
|
names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id)
|
||||||
|
assert new_name in names
|
||||||
|
assert print_tool.name not in names
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||||
|
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
|
||||||
|
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(
|
||||||
|
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
|
||||||
|
with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
|
||||||
|
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
|
||||||
|
agent_id=sarah_agent.id, tool_name=print_tool.name
|
||||||
|
)
|
||||||
|
|
||||||
|
assert removed_tool.tool_name == print_tool.name
|
||||||
|
assert removed_tool.tool_id == print_tool.id
|
||||||
|
assert removed_tool.agent_id == sarah_agent.id
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"):
|
||||||
|
server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||||
|
|
||||||
|
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
|
||||||
|
|
||||||
|
assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
|
||||||
|
agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id)
|
||||||
|
|
||||||
|
assert sarah_agent.id in agent_ids
|
||||||
|
assert charles_agent.id in agent_ids
|
||||||
|
assert len(agent_ids) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||||
|
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
|
||||||
|
tool_manager = ToolManager()
|
||||||
|
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
|
||||||
|
|
||||||
|
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
|
||||||
|
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||||
|
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||||
|
|
||||||
|
server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id)
|
||||||
|
|
||||||
|
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
|
||||||
|
|
||||||
|
assert not retrieved_tool_ids
|
||||||
|
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# JobManager Tests
|
# JobManager Tests
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user