diff --git a/tests/test_managers.py b/tests/test_managers.py index c79caf93..6700e370 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -9,7 +9,7 @@ import time import uuid from datetime import datetime, timedelta, timezone from typing import List -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock, patch # tests/test_file_content_flow.py import pytest @@ -9743,13 +9743,44 @@ async def test_count_batch_items( @pytest.mark.asyncio -async def test_create_mcp_server(server, default_user): +@patch("letta.services.mcp_manager.MCPManager.get_mcp_client") +async def test_create_mcp_server(mock_get_client, server, default_user): from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig from letta.settings import tool_settings if tool_settings.mcp_read_from_config: return + # create mock client with required methods + mock_client = AsyncMock() + mock_client.connect_to_server = AsyncMock() + mock_client.list_tools = AsyncMock( + return_value=[ + MCPTool( + name="get_simple_price", + inputSchema={ + "type": "object", + "properties": { + "ids": {"type": "string"}, + "vs_currencies": {"type": "string"}, + "include_market_cap": {"type": "boolean"}, + "include_24hr_vol": {"type": "boolean"}, + "include_24hr_change": {"type": "boolean"}, + }, + "required": ["ids", "vs_currencies"], + "additionalProperties": False, + }, + ) + ] + ) + mock_client.execute_tool = AsyncMock( + return_value=( + '{"bitcoin": {"usd": 50000, "usd_market_cap": 900000000000, "usd_24h_vol": 30000000000, "usd_24h_change": 2.5}}', + True, + ) + ) + mock_get_client.return_value = mock_client + # Test with a valid StdioServerConfig server_config = StdioServerConfig( server_name="test_server", type=MCPServerType.STDIO, command="echo 'test'", args=["arg1", "arg2"], env={"ENV1": "value1"} @@ -9761,8 +9792,8 @@ async def test_create_mcp_server(server, default_user): assert created_server.server_type == server_config.type # Test with a valid SSEServerConfig - mcp_server_name = "devin" - server_url = "https://mcp.deepwiki.com/sse" + mcp_server_name = "coingecko" + server_url = "https://mcp.api.coingecko.com/sse" sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url) mcp_sse_server = MCPServer(server_name=mcp_server_name, server_type=MCPServerType.SSE, server_url=server_url) created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_sse_server, actor=default_user) @@ -9780,8 +9811,14 @@ async def test_create_mcp_server(server, default_user): print(tools) # call a tool from the sse server - tool_name = "ask_question" - tool_args = {"repoName": "letta-ai/letta", "question": "What is the primary programming language of this repository?"} + tool_name = "get_simple_price" + tool_args = { + "ids": "bitcoin", + "vs_currencies": "usd", + "include_market_cap": True, + "include_24hr_vol": True, + "include_24hr_change": True, + } result = await server.mcp_manager.execute_mcp_server_tool( created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={} )