feat: Add metadata to MCP tools (#1325)
This commit is contained in:
@@ -398,7 +398,7 @@ def add_mcp_tool(
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, actor=actor)
|
||||
return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
|
||||
|
||||
|
||||
@router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server")
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import BASE_FUNCTION_RETURN_CHAR_LIMIT, BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS
|
||||
from letta.constants import BASE_FUNCTION_RETURN_CHAR_LIMIT, BASE_MEMORY_TOOLS, BASE_TOOLS, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
@@ -57,9 +57,13 @@ class ToolManager:
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
def create_or_update_mcp_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
|
||||
return self.create_or_update_tool(
|
||||
PydanticTool(tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
|
||||
PydanticTool(
|
||||
tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump()
|
||||
),
|
||||
actor,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -8,9 +8,10 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MULTI_AGENT_TOOLS
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.orm import Base
|
||||
from letta.orm.enums import JobType, ToolType
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
@@ -167,6 +168,30 @@ def composio_github_star_tool(server, default_user):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_tool(server, default_user):
|
||||
mcp_tool = MCPTool(
|
||||
name="weather_lookup",
|
||||
description="Fetches the current weather for a given location.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The name of the city or location."},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["metric", "imperial"],
|
||||
"description": "The unit system for temperature (metric or imperial).",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
mcp_server_name = "test"
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
tool = server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_job(server: SyncServer, default_user):
|
||||
"""Fixture to create and return a default job."""
|
||||
@@ -1817,6 +1842,14 @@ def test_create_composio_tool(server: SyncServer, composio_github_star_tool, def
|
||||
assert composio_github_star_tool.tool_type == ToolType.EXTERNAL_COMPOSIO
|
||||
|
||||
|
||||
def test_create_mcp_tool(server: SyncServer, mcp_tool, default_user, default_organization):
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert mcp_tool.created_by_id == default_user.id
|
||||
assert mcp_tool.organization_id == default_organization.id
|
||||
assert mcp_tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
assert mcp_tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||
data = print_tool.model_dump(exclude=["id"])
|
||||
|
||||
Reference in New Issue
Block a user