feat: orm ToolsAgents migration (#2173)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-05 17:47:22 -08:00
committed by GitHub
parent 8c0aa76c4f
commit c18d6fd471
11 changed files with 369 additions and 7 deletions

View File

@@ -16,6 +16,7 @@ from letta.orm import (
SandboxEnvironmentVariable,
Source,
Tool,
ToolsAgents,
User,
)
from letta.orm.agents_tags import AgentsTags
@@ -46,9 +47,10 @@ from letta.schemas.sandbox_config import (
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.services.block_manager import BlockManager
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.settings import tool_settings
utils.DEBUG = True
@@ -76,6 +78,7 @@ def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Job))
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
session.execute(delete(BlocksAgents))
session.execute(delete(AgentsTags))
session.execute(delete(SandboxEnvironmentVariable))
@@ -240,6 +243,37 @@ def other_block(server: SyncServer, default_user):
yield block
@pytest.fixture
def other_tool(server: SyncServer, default_user, default_organization):
def print_other_tool(message: str):
"""
Args:
message (str): The message to print.
Returns:
str: The message that was printed.
"""
print(message)
return message
# Set up tool details
source_code = parse_source_code(print_other_tool)
source_type = "python"
description = "other_tool_description"
tags = ["test"]
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
tool.json_schema = derived_json_schema
tool.name = derived_name
tool = server.tool_manager.create_tool(tool, actor=default_user)
# Yield the created tool
yield tool
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
@@ -1155,6 +1189,107 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
# ======================================================================================================================
# ToolsAgentsManager Tests
# ======================================================================================================================
def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool):
tool_association = server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
)
assert tool_association.agent_id == sarah_agent.id
assert tool_association.tool_id == print_tool.id
assert tool_association.tool_name == print_tool.name
def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool):
# Add the tool
tool_association = server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
)
assert tool_association.tool_name == print_tool.name
# Change the tool name
new_name = "banana"
tool = server.tool_manager.update_tool_by_id(
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
)
assert tool.name == new_name
# Get the association
names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id)
assert new_name in names
assert print_tool.name not in names
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
)
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name)
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
agent_id=sarah_agent.id, tool_name=print_tool.name
)
assert removed_tool.tool_name == print_tool.name
assert removed_tool.tool_id == print_tool.id
assert removed_tool.agent_id == sarah_agent.id
with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"):
server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id}
def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id)
assert sarah_agent.id in agent_ids
assert charles_agent.id in agent_ids
assert len(agent_ids) == 2
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
tool_manager = ToolManager()
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id)
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
assert not retrieved_tool_ids
# ======================================================================================================================
# JobManager Tests
# ======================================================================================================================