From c18d6fd471f1d3ce09f08dd61d8a2cff436e3e2b Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Thu, 5 Dec 2024 17:47:22 -0800 Subject: [PATCH] feat: orm ToolsAgents migration (#2173) Co-authored-by: Mindy Long --- .gitignore | 1 + .../08b2f8225812_adding_toolsagents_orm.py | 44 ++++++ letta/functions/function_sets/base.py | 8 +- letta/orm/__init__.py | 1 + letta/orm/tool.py | 23 ++- letta/orm/tools_agents.py | 32 ++++ letta/schemas/tools_agents.py | 32 ++++ letta/server/server.py | 2 + letta/services/tool_manager.py | 2 +- letta/services/tools_agents_manager.py | 94 ++++++++++++ tests/test_managers.py | 137 +++++++++++++++++- 11 files changed, 369 insertions(+), 7 deletions(-) create mode 100644 alembic/versions/08b2f8225812_adding_toolsagents_orm.py create mode 100644 letta/orm/tools_agents.py create mode 100644 letta/schemas/tools_agents.py create mode 100644 letta/services/tools_agents_manager.py diff --git a/.gitignore b/.gitignore index 9fb91c2b..12042451 100644 --- a/.gitignore +++ b/.gitignore @@ -551,6 +551,7 @@ tags [._]*.un~ ### VisualStudioCode ### +.vscode/ .vscode/* !.vscode/settings.json !.vscode/tasks.json diff --git a/alembic/versions/08b2f8225812_adding_toolsagents_orm.py b/alembic/versions/08b2f8225812_adding_toolsagents_orm.py new file mode 100644 index 00000000..902225ab --- /dev/null +++ b/alembic/versions/08b2f8225812_adding_toolsagents_orm.py @@ -0,0 +1,44 @@ +"""adding ToolsAgents ORM + +Revision ID: 08b2f8225812 +Revises: 3c683a662c82 +Create Date: 2024-12-05 16:46:51.258831 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '08b2f8225812' +down_revision: Union[str, None] = '3c683a662c82' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tools_agents', + sa.Column('agent_id', sa.String(), nullable=False), + sa.Column('tool_id', sa.String(), nullable=False), + sa.Column('tool_name', sa.String(), nullable=False), + sa.Column('id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.String(), nullable=True), + sa.Column('_last_updated_by_id', sa.String(), nullable=True), + sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ), + sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'), + sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'), + sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tools_agents') + # ### end Alembic commands ### diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index e7bd4a9d..6e963128 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -56,7 +56,7 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]: pause_heartbeats.__doc__ = pause_heartbeats_docstring -def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]: +def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using case-insensitive string matching. @@ -91,7 +91,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt return results_str -def conversation_search_date(self: Agent, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]: +def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using a date range. @@ -126,7 +126,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page: return results_str -def archival_memory_insert(self: Agent, content: str) -> Optional[str]: +def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. @@ -140,7 +140,7 @@ def archival_memory_insert(self: Agent, content: str) -> Optional[str]: return None -def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]: +def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 42988112..8d47ba45 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -7,4 +7,5 @@ from letta.orm.organization import Organization from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.tool import Tool +from letta.orm.tools_agents import ToolsAgents from letta.orm.user import User diff --git a/letta/orm/tool.py b/letta/orm/tool.py index d86fffa2..00038fe0 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, String, UniqueConstraint +from sqlalchemy import JSON, String, UniqueConstraint, event from sqlalchemy.orm import Mapped, mapped_column, relationship # TODO everything in functions should live in this model @@ -11,6 +11,7 @@ from letta.schemas.tool import Tool as PydanticTool if TYPE_CHECKING: from letta.orm.organization import Organization + from letta.orm.tools_agents import ToolsAgents class Tool(SqlalchemyBase, OrganizationMixin): @@ -40,3 +41,23 @@ class Tool(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") + tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan") + + +# Add event listener to update tool_name in ToolsAgents when Tool name changes +@event.listens_for(Tool, 'before_update') +def update_tool_name_in_tools_agents(mapper, connection, target): + """Update tool_name in ToolsAgents when Tool name changes.""" + state = target._sa_instance_state + history = state.get_history('name', passive=True) + if not history.has_changes(): + return + + # Get the new name and update all associated ToolsAgents records + new_name = target.name + from letta.orm.tools_agents import ToolsAgents + connection.execute( + ToolsAgents.__table__.update().where( + ToolsAgents.tool_id == target.id + ).values(tool_name=new_name) + ) diff --git a/letta/orm/tools_agents.py b/letta/orm/tools_agents.py new file mode 100644 index 00000000..dfb8a9a7 --- /dev/null +++ b/letta/orm/tools_agents.py @@ -0,0 +1,32 @@ +from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents + + +class ToolsAgents(SqlalchemyBase): + """Agents can have one or many tools associated with them.""" + + __tablename__ = "tools_agents" + __pydantic_model__ = PydanticToolsAgents + __table_args__ = ( + UniqueConstraint( + "agent_id", + "tool_name", + name="unique_tool_per_agent", + ), + ForeignKeyConstraint( + ["tool_id"], + ["tools.id"], + name="fk_tool_id", + ), + ) + + # Each agent must have unique tool names + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) + tool_id: Mapped[str] = mapped_column(String, primary_key=True) + tool_name: Mapped[str] = mapped_column(String, primary_key=True) + + # relationships + tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents") diff --git a/letta/schemas/tools_agents.py b/letta/schemas/tools_agents.py new file mode 100644 index 00000000..b7e8bdcf --- /dev/null +++ b/letta/schemas/tools_agents.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase + + +class ToolsAgentsBase(LettaBase): + __id_prefix__ = "tools_agents" + + +class ToolsAgents(ToolsAgentsBase): + """ + Schema representing the relationship between tools and agents. + + Parameters: + agent_id (str): The ID of the associated agent. + tool_id (str): The ID of the associated tool. + tool_name (str): The name of the tool. + created_at (datetime): The date this relationship was created. + updated_at (datetime): The date this relationship was last updated. + is_deleted (bool): Whether this tool-agent relationship is deleted or not. + """ + + id: str = ToolsAgentsBase.generate_id_field() + agent_id: str = Field(..., description="The ID of the associated agent.") + tool_id: str = Field(..., description="The ID of the associated tool.") + tool_name: str = Field(..., description="The name of the tool.") + created_at: Optional[datetime] = Field(None, description="The creation date of the association.") + updated_at: Optional[datetime] = Field(None, description="The update date of the association.") + is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.") diff --git a/letta/server/server.py b/letta/server/server.py index 23e5f211..d12e0f3b 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -77,6 +77,7 @@ from letta.schemas.user import User from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.block_manager import BlockManager from letta.services.blocks_agents_manager import BlocksAgentsManager +from letta.services.tools_agents_manager import ToolsAgentsManager from letta.services.job_manager import JobManager from letta.services.organization_manager import OrganizationManager from letta.services.per_agent_lock_manager import PerAgentLockManager @@ -259,6 +260,7 @@ class SyncServer(Server): self.agents_tags_manager = AgentsTagsManager() self.sandbox_config_manager = SandboxConfigManager(tool_settings) self.blocks_agents_manager = BlocksAgentsManager() + self.tools_agents_manager = ToolsAgentsManager() self.job_manager = JobManager() # Managers that interface with parallelism diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 7acf4fa5..33b1afd7 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -130,7 +130,7 @@ class ToolManager: with self.session_maker() as session: try: tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) - tool.delete(db_session=session, actor=actor) + tool.hard_delete(db_session=session, actor=actor) except NoResultFound: raise ValueError(f"Tool with id {tool_id} not found.") diff --git a/letta/services/tools_agents_manager.py b/letta/services/tools_agents_manager.py new file mode 100644 index 00000000..35b24e5a --- /dev/null +++ b/letta/services/tools_agents_manager.py @@ -0,0 +1,94 @@ +import warnings +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from letta.orm.errors import NoResultFound +from letta.orm.organization import Organization +from letta.orm.tool import Tool +from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel +from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents + +class ToolsAgentsManager: + """Manages the relationship between tools and agents.""" + + def __init__(self): + from letta.server.server import db_context + self.session_maker = db_context + + def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents: + """Add a tool to an agent. + + When a tool is added to an agent, it will be added to all agents in the same organization. + """ + with self.session_maker() as session: + try: + # Check if the tool-agent combination already exists for this agent + tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name) + warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.") + except NoResultFound: + tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name) + tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True)) + tools_agents_record.create(session) + + return tools_agents_record.to_pydantic() + + def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None: + """Remove a tool from an agent by its name. + + When a tool is removed from an agent, it will be removed from all agents in the same organization. + """ + with self.session_maker() as session: + try: + # Find and delete the tool-agent association for the agent + tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name) + tools_agents_record.hard_delete(session) + return tools_agents_record.to_pydantic() + except NoResultFound: + raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.") + + def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents: + """Remove a tool with an ID from an agent.""" + with self.session_maker() as session: + try: + tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id) + tools_agents_record.hard_delete(session) + return tools_agents_record.to_pydantic() + except NoResultFound: + raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.") + + def list_tool_ids_for_agent(self, agent_id: str) -> List[str]: + """List all tool IDs associated with a specific agent.""" + with self.session_maker() as session: + tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id) + return [record.tool_id for record in tools_agents_record] + + def list_tool_names_for_agent(self, agent_id: str) -> List[str]: + """List all tool names associated with a specific agent.""" + with self.session_maker() as session: + tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id) + return [record.tool_name for record in tools_agents_record] + + def list_agent_ids_with_tool(self, tool_id: str) -> List[str]: + """List all agents associated with a specific tool.""" + with self.session_maker() as session: + tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id) + return [record.agent_id for record in tools_agents_record] + + def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str: + """Get the tool ID for a specific tool name for an agent.""" + with self.session_maker() as session: + try: + tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name) + return tools_agents_record.tool_id + except NoResultFound: + raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.") + + def remove_all_agent_tools(self, agent_id: str) -> None: + """Remove all tools associated with an agent.""" + with self.session_maker() as session: + tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id) + for record in tools_agents_records: + record.hard_delete(session) \ No newline at end of file diff --git a/tests/test_managers.py b/tests/test_managers.py index 7b5f86f2..20022ccc 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -16,6 +16,7 @@ from letta.orm import ( SandboxEnvironmentVariable, Source, Tool, + ToolsAgents, User, ) from letta.orm.agents_tags import AgentsTags @@ -46,9 +47,10 @@ from letta.schemas.sandbox_config import ( from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool -from letta.schemas.tool import ToolUpdate +from letta.schemas.tool import ToolCreate, ToolUpdate from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager +from letta.services.tool_manager import ToolManager from letta.settings import tool_settings utils.DEBUG = True @@ -76,6 +78,7 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: session.execute(delete(Job)) + session.execute(delete(ToolsAgents)) # Clear ToolsAgents first session.execute(delete(BlocksAgents)) session.execute(delete(AgentsTags)) session.execute(delete(SandboxEnvironmentVariable)) @@ -240,6 +243,37 @@ def other_block(server: SyncServer, default_user): yield block +@pytest.fixture +def other_tool(server: SyncServer, default_user, default_organization): + def print_other_tool(message: str): + """ + Args: + message (str): The message to print. + + Returns: + str: The message that was printed. + """ + print(message) + return message + + # Set up tool details + source_code = parse_source_code(print_other_tool) + source_type = "python" + description = "other_tool_description" + tags = ["test"] + + tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + tool = server.tool_manager.create_tool(tool, actor=default_user) + + # Yield the created tool + yield tool + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() @@ -1155,6 +1189,107 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) +# ====================================================================================================================== +# ToolsAgentsManager Tests +# ====================================================================================================================== +def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool): + tool_association = server.tools_agents_manager.add_tool_to_agent( + agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name + ) + + assert tool_association.agent_id == sarah_agent.id + assert tool_association.tool_id == print_tool.id + assert tool_association.tool_name == print_tool.name + + +def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool): + # Add the tool + tool_association = server.tools_agents_manager.add_tool_to_agent( + agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name + ) + assert tool_association.tool_name == print_tool.name + + # Change the tool name + new_name = "banana" + tool = server.tool_manager.update_tool_by_id( + tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user + ) + assert tool.name == new_name + + # Get the association + names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id) + assert new_name in names + assert print_tool.name not in names + + +@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite") +def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user): + with pytest.raises(ForeignKeyConstraintViolationError): + server.tools_agents_manager.add_tool_to_agent( + agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name" + ) + + +def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + + with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name) + + +def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + + removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent( + agent_id=sarah_agent.id, tool_name=print_tool.name + ) + + assert removed_tool.tool_name == print_tool.name + assert removed_tool.tool_id == print_tool.id + assert removed_tool.agent_id == sarah_agent.id + + with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"): + server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name) + + +def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name) + + retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id) + + assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id} + + +def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + + agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id) + + assert sarah_agent.id in agent_ids + assert charles_agent.id in agent_ids + assert len(agent_ids) == 2 + + +@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite") +def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool): + tool_manager = ToolManager() + tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user) + + with pytest.raises(ForeignKeyConstraintViolationError): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + +def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool): + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name) + server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name) + + server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id) + + retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id) + + assert not retrieved_tool_ids + # ====================================================================================================================== # JobManager Tests # ======================================================================================================================