@@ -9,7 +9,7 @@ import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
# tests/test_file_content_flow.py
|
||||
import pytest
|
||||
@@ -9743,13 +9743,44 @@ async def test_count_batch_items(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mcp_server(server, default_user):
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server(mock_get_client, server, default_user):
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# create mock client with required methods
|
||||
mock_client = AsyncMock()
|
||||
mock_client.connect_to_server = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(
|
||||
return_value=[
|
||||
MCPTool(
|
||||
name="get_simple_price",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ids": {"type": "string"},
|
||||
"vs_currencies": {"type": "string"},
|
||||
"include_market_cap": {"type": "boolean"},
|
||||
"include_24hr_vol": {"type": "boolean"},
|
||||
"include_24hr_change": {"type": "boolean"},
|
||||
},
|
||||
"required": ["ids", "vs_currencies"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
mock_client.execute_tool = AsyncMock(
|
||||
return_value=(
|
||||
'{"bitcoin": {"usd": 50000, "usd_market_cap": 900000000000, "usd_24h_vol": 30000000000, "usd_24h_change": 2.5}}',
|
||||
True,
|
||||
)
|
||||
)
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Test with a valid StdioServerConfig
|
||||
server_config = StdioServerConfig(
|
||||
server_name="test_server", type=MCPServerType.STDIO, command="echo 'test'", args=["arg1", "arg2"], env={"ENV1": "value1"}
|
||||
@@ -9761,8 +9792,8 @@ async def test_create_mcp_server(server, default_user):
|
||||
assert created_server.server_type == server_config.type
|
||||
|
||||
# Test with a valid SSEServerConfig
|
||||
mcp_server_name = "devin"
|
||||
server_url = "https://mcp.deepwiki.com/sse"
|
||||
mcp_server_name = "coingecko"
|
||||
server_url = "https://mcp.api.coingecko.com/sse"
|
||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
mcp_sse_server = MCPServer(server_name=mcp_server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_sse_server, actor=default_user)
|
||||
@@ -9780,8 +9811,14 @@ async def test_create_mcp_server(server, default_user):
|
||||
print(tools)
|
||||
|
||||
# call a tool from the sse server
|
||||
tool_name = "ask_question"
|
||||
tool_args = {"repoName": "letta-ai/letta", "question": "What is the primary programming language of this repository?"}
|
||||
tool_name = "get_simple_price"
|
||||
tool_args = {
|
||||
"ids": "bitcoin",
|
||||
"vs_currencies": "usd",
|
||||
"include_market_cap": True,
|
||||
"include_24hr_vol": True,
|
||||
"include_24hr_change": True,
|
||||
}
|
||||
result = await server.mcp_manager.execute_mcp_server_tool(
|
||||
created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user