feat: orm ToolsAgents migration (#2173)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from letta.orm import (
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
Tool,
|
||||
ToolsAgents,
|
||||
User,
|
||||
)
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
@@ -46,9 +47,10 @@ from letta.schemas.sandbox_config import (
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
utils.DEBUG = True
|
||||
@@ -76,6 +78,7 @@ def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
session.execute(delete(AgentsTags))
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
@@ -240,6 +243,37 @@ def other_block(server: SyncServer, default_user):
|
||||
yield block
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_tool(server: SyncServer, default_user, default_organization):
|
||||
def print_other_tool(message: str):
|
||||
"""
|
||||
Args:
|
||||
message (str): The message to print.
|
||||
|
||||
Returns:
|
||||
str: The message that was printed.
|
||||
"""
|
||||
print(message)
|
||||
return message
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(print_other_tool)
|
||||
source_type = "python"
|
||||
description = "other_tool_description"
|
||||
tags = ["test"]
|
||||
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = server.tool_manager.create_tool(tool, actor=default_user)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
@@ -1155,6 +1189,107 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# ToolsAgentsManager Tests
|
||||
# ======================================================================================================================
|
||||
def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool):
|
||||
tool_association = server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
|
||||
)
|
||||
|
||||
assert tool_association.agent_id == sarah_agent.id
|
||||
assert tool_association.tool_id == print_tool.id
|
||||
assert tool_association.tool_name == print_tool.name
|
||||
|
||||
|
||||
def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool):
|
||||
# Add the tool
|
||||
tool_association = server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
|
||||
)
|
||||
assert tool_association.tool_name == print_tool.name
|
||||
|
||||
# Change the tool name
|
||||
new_name = "banana"
|
||||
tool = server.tool_manager.update_tool_by_id(
|
||||
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
|
||||
)
|
||||
assert tool.name == new_name
|
||||
|
||||
# Get the association
|
||||
names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id)
|
||||
assert new_name in names
|
||||
assert print_tool.name not in names
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
|
||||
)
|
||||
|
||||
|
||||
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name)
|
||||
|
||||
|
||||
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
|
||||
agent_id=sarah_agent.id, tool_name=print_tool.name
|
||||
)
|
||||
|
||||
assert removed_tool.tool_name == print_tool.name
|
||||
assert removed_tool.tool_id == print_tool.id
|
||||
assert removed_tool.agent_id == sarah_agent.id
|
||||
|
||||
with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"):
|
||||
server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
|
||||
|
||||
|
||||
def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||
|
||||
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
|
||||
|
||||
assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id}
|
||||
|
||||
|
||||
def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id)
|
||||
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||
|
||||
server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id)
|
||||
|
||||
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
|
||||
|
||||
assert not retrieved_tool_ids
|
||||
|
||||
# ======================================================================================================================
|
||||
# JobManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user