From 8e93497dfc615e722dabcf0bf5e0d6284ff1d591 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 17 Mar 2025 17:39:59 -0700 Subject: [PATCH] feat: Add metadata to MCP tools (#1325) --- letta/server/rest_api/routers/v1/tools.py | 2 +- letta/services/tool_manager.py | 10 +++++-- tests/test_managers.py | 35 ++++++++++++++++++++++- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index b4423027..b302a569 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -398,7 +398,7 @@ def add_mcp_tool( ) tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) - return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, actor=actor) + return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor) @router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index e5eb1d2f..2b6e5a28 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -2,7 +2,7 @@ import importlib import warnings from typing import List, Optional -from letta.constants import BASE_FUNCTION_RETURN_CHAR_LIMIT, BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS +from letta.constants import BASE_FUNCTION_RETURN_CHAR_LIMIT, BASE_MEMORY_TOOLS, BASE_TOOLS, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS from letta.functions.functions import derive_openai_json_schema, load_function_set from letta.log import get_logger from letta.orm.enums import ToolType @@ -57,9 +57,13 @@ class ToolManager: return tool @enforce_types - def create_or_update_mcp_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: + def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: + metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}} return self.create_or_update_tool( - PydanticTool(tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor + PydanticTool( + tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump() + ), + actor, ) @enforce_types diff --git a/tests/test_managers.py b/tests/test_managers.py index e8da8406..6bd11b18 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8,9 +8,10 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open from sqlalchemy.exc import IntegrityError from letta.config import LettaConfig -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MULTI_AGENT_TOOLS +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code +from letta.functions.mcp_client.types import MCPTool from letta.orm import Base from letta.orm.enums import JobType, ToolType from letta.orm.errors import NoResultFound, UniqueConstraintViolationError @@ -167,6 +168,30 @@ def composio_github_star_tool(server, default_user): yield tool +@pytest.fixture +def mcp_tool(server, default_user): + mcp_tool = MCPTool( + name="weather_lookup", + description="Fetches the current weather for a given location.", + inputSchema={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The name of the city or location."}, + "units": { + "type": "string", + "enum": ["metric", "imperial"], + "description": "The unit system for temperature (metric or imperial).", + }, + }, + "required": ["location"], + }, + ) + mcp_server_name = "test" + tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) + tool = server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user) + yield tool + + @pytest.fixture def default_job(server: SyncServer, default_user): """Fixture to create and return a default job.""" @@ -1817,6 +1842,14 @@ def test_create_composio_tool(server: SyncServer, composio_github_star_tool, def assert composio_github_star_tool.tool_type == ToolType.EXTERNAL_COMPOSIO +def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_organization): + # Assertions to ensure the created tool matches the expected values + assert mcp_tool.created_by_id == default_user.id + assert mcp_tool.organization_id == default_organization.id + assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP + assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test" + + @pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.") def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization): data = print_tool.model_dump(exclude=["id"])