feat: orm ToolsAgents migration (#2173)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-05 17:47:22 -08:00
committed by GitHub
parent 8c0aa76c4f
commit c18d6fd471
11 changed files with 369 additions and 7 deletions

1
.gitignore vendored
View File

@@ -551,6 +551,7 @@ tags
[._]*.un~
### VisualStudioCode ###
.vscode/
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json

View File

@@ -0,0 +1,44 @@
"""adding ToolsAgents ORM
Revision ID: 08b2f8225812
Revises: 3c683a662c82
Create Date: 2024-12-05 16:46:51.258831
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '08b2f8225812'
down_revision: Union[str, None] = '3c683a662c82'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tools_agents',
sa.Column('agent_id', sa.String(), nullable=False),
sa.Column('tool_id', sa.String(), nullable=False),
sa.Column('tool_name', sa.String(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
sa.Column('_created_by_id', sa.String(), nullable=True),
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ),
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'),
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'),
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tools_agents')
# ### end Alembic commands ###

View File

@@ -56,7 +56,7 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
pause_heartbeats.__doc__ = pause_heartbeats_docstring
def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
@@ -91,7 +91,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt
return results_str
def conversation_search_date(self: Agent, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using a date range.
@@ -126,7 +126,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page:
return results_str
def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
@@ -140,7 +140,7 @@ def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
return None
def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.

View File

@@ -7,4 +7,5 @@ from letta.orm.organization import Organization
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.tool import Tool
from letta.orm.tools_agents import ToolsAgents
from letta.orm.user import User

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy import JSON, String, UniqueConstraint, event
from sqlalchemy.orm import Mapped, mapped_column, relationship
# TODO everything in functions should live in this model
@@ -11,6 +11,7 @@ from letta.schemas.tool import Tool as PydanticTool
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.tools_agents import ToolsAgents
class Tool(SqlalchemyBase, OrganizationMixin):
@@ -40,3 +41,23 @@ class Tool(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
# Add event listener to update tool_name in ToolsAgents when Tool name changes
@event.listens_for(Tool, 'before_update')
def update_tool_name_in_tools_agents(mapper, connection, target):
"""Update tool_name in ToolsAgents when Tool name changes."""
state = target._sa_instance_state
history = state.get_history('name', passive=True)
if not history.has_changes():
return
# Get the new name and update all associated ToolsAgents records
new_name = target.name
from letta.orm.tools_agents import ToolsAgents
connection.execute(
ToolsAgents.__table__.update().where(
ToolsAgents.tool_id == target.id
).values(tool_name=new_name)
)

32
letta/orm/tools_agents.py Normal file
View File

@@ -0,0 +1,32 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
class ToolsAgents(SqlalchemyBase):
"""Agents can have one or many tools associated with them."""
__tablename__ = "tools_agents"
__pydantic_model__ = PydanticToolsAgents
__table_args__ = (
UniqueConstraint(
"agent_id",
"tool_name",
name="unique_tool_per_agent",
),
ForeignKeyConstraint(
["tool_id"],
["tools.id"],
name="fk_tool_id",
),
)
# Each agent must have unique tool names
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
tool_id: Mapped[str] = mapped_column(String, primary_key=True)
tool_name: Mapped[str] = mapped_column(String, primary_key=True)
# relationships
tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents")

View File

@@ -0,0 +1,32 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class ToolsAgentsBase(LettaBase):
__id_prefix__ = "tools_agents"
class ToolsAgents(ToolsAgentsBase):
"""
Schema representing the relationship between tools and agents.
Parameters:
agent_id (str): The ID of the associated agent.
tool_id (str): The ID of the associated tool.
tool_name (str): The name of the tool.
created_at (datetime): The date this relationship was created.
updated_at (datetime): The date this relationship was last updated.
is_deleted (bool): Whether this tool-agent relationship is deleted or not.
"""
id: str = ToolsAgentsBase.generate_id_field()
agent_id: str = Field(..., description="The ID of the associated agent.")
tool_id: str = Field(..., description="The ID of the associated tool.")
tool_name: str = Field(..., description="The name of the tool.")
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.")

View File

@@ -77,6 +77,7 @@ from letta.schemas.user import User
from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.tools_agents_manager import ToolsAgentsManager
from letta.services.job_manager import JobManager
from letta.services.organization_manager import OrganizationManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
@@ -259,6 +260,7 @@ class SyncServer(Server):
self.agents_tags_manager = AgentsTagsManager()
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.blocks_agents_manager = BlocksAgentsManager()
self.tools_agents_manager = ToolsAgentsManager()
self.job_manager = JobManager()
# Managers that interface with parallelism

View File

@@ -130,7 +130,7 @@ class ToolManager:
with self.session_maker() as session:
try:
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
tool.delete(db_session=session, actor=actor)
tool.hard_delete(db_session=session, actor=actor)
except NoResultFound:
raise ValueError(f"Tool with id {tool_id} not found.")

View File

@@ -0,0 +1,94 @@
import warnings
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization
from letta.orm.tool import Tool
from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
class ToolsAgentsManager:
"""Manages the relationship between tools and agents."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents:
"""Add a tool to an agent.
When a tool is added to an agent, it will be added to all agents in the same organization.
"""
with self.session_maker() as session:
try:
# Check if the tool-agent combination already exists for this agent
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.")
except NoResultFound:
tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name)
tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True))
tools_agents_record.create(session)
return tools_agents_record.to_pydantic()
def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None:
"""Remove a tool from an agent by its name.
When a tool is removed from an agent, it will be removed from all agents in the same organization.
"""
with self.session_maker() as session:
try:
# Find and delete the tool-agent association for the agent
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
tools_agents_record.hard_delete(session)
return tools_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents:
"""Remove a tool with an ID from an agent."""
with self.session_maker() as session:
try:
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id)
tools_agents_record.hard_delete(session)
return tools_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.")
def list_tool_ids_for_agent(self, agent_id: str) -> List[str]:
"""List all tool IDs associated with a specific agent."""
with self.session_maker() as session:
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.tool_id for record in tools_agents_record]
def list_tool_names_for_agent(self, agent_id: str) -> List[str]:
"""List all tool names associated with a specific agent."""
with self.session_maker() as session:
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.tool_name for record in tools_agents_record]
def list_agent_ids_with_tool(self, tool_id: str) -> List[str]:
"""List all agents associated with a specific tool."""
with self.session_maker() as session:
tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id)
return [record.agent_id for record in tools_agents_record]
def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str:
"""Get the tool ID for a specific tool name for an agent."""
with self.session_maker() as session:
try:
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
return tools_agents_record.tool_id
except NoResultFound:
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
def remove_all_agent_tools(self, agent_id: str) -> None:
"""Remove all tools associated with an agent."""
with self.session_maker() as session:
tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
for record in tools_agents_records:
record.hard_delete(session)

View File

@@ -16,6 +16,7 @@ from letta.orm import (
SandboxEnvironmentVariable,
Source,
Tool,
ToolsAgents,
User,
)
from letta.orm.agents_tags import AgentsTags
@@ -46,9 +47,10 @@ from letta.schemas.sandbox_config import (
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.services.block_manager import BlockManager
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.settings import tool_settings
utils.DEBUG = True
@@ -76,6 +78,7 @@ def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Job))
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
session.execute(delete(BlocksAgents))
session.execute(delete(AgentsTags))
session.execute(delete(SandboxEnvironmentVariable))
@@ -240,6 +243,37 @@ def other_block(server: SyncServer, default_user):
yield block
@pytest.fixture
def other_tool(server: SyncServer, default_user, default_organization):
def print_other_tool(message: str):
"""
Args:
message (str): The message to print.
Returns:
str: The message that was printed.
"""
print(message)
return message
# Set up tool details
source_code = parse_source_code(print_other_tool)
source_type = "python"
description = "other_tool_description"
tags = ["test"]
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
tool.json_schema = derived_json_schema
tool.name = derived_name
tool = server.tool_manager.create_tool(tool, actor=default_user)
# Yield the created tool
yield tool
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
@@ -1155,6 +1189,107 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
# ======================================================================================================================
# ToolsAgentsManager Tests
# ======================================================================================================================
def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool):
tool_association = server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
)
assert tool_association.agent_id == sarah_agent.id
assert tool_association.tool_id == print_tool.id
assert tool_association.tool_name == print_tool.name
def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool):
# Add the tool
tool_association = server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
)
assert tool_association.tool_name == print_tool.name
# Change the tool name
new_name = "banana"
tool = server.tool_manager.update_tool_by_id(
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
)
assert tool.name == new_name
# Get the association
names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id)
assert new_name in names
assert print_tool.name not in names
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
)
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name)
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
agent_id=sarah_agent.id, tool_name=print_tool.name
)
assert removed_tool.tool_name == print_tool.name
assert removed_tool.tool_id == print_tool.id
assert removed_tool.agent_id == sarah_agent.id
with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"):
server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id}
def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id)
assert sarah_agent.id in agent_ids
assert charles_agent.id in agent_ids
assert len(agent_ids) == 2
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
tool_manager = ToolManager()
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id)
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
assert not retrieved_tool_ids
# ======================================================================================================================
# JobManager Tests
# ======================================================================================================================