diff --git a/tests/mcp/mcp_config.json b/tests/mcp/mcp_config.json index bf5234fe..0967ef42 100644 --- a/tests/mcp/mcp_config.json +++ b/tests/mcp/mcp_config.json @@ -1,8 +1 @@ -{ - "mcpServers": { - "github_composio": { - "transport": "sse", - "url": "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8" - } - } -} +{} diff --git a/tests/mcp/test_mcp.py b/tests/mcp/test_mcp.py index 01c5d801..5e7550cd 100644 --- a/tests/mcp/test_mcp.py +++ b/tests/mcp/test_mcp.py @@ -1,18 +1,21 @@ import json import os import subprocess +import threading +import uuid import venv from pathlib import Path import pytest -from mcp import Tool as MCPTool +from dotenv import load_dotenv +from letta_client import Letta, McpTool, ToolCallMessage, ToolReturnMessage -import letta.constants as constants -from letta.config import LettaConfig -from letta.functions.mcp_client.types import MCPServerType, SSEServerConfig, StdioServerConfig -from letta.schemas.tool import ToolCreate -from letta.server.server import SyncServer -from letta.utils import parse_json +from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message_content import TextContent +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from tests.utils import wait_for_server def create_virtualenv_and_install_requirements(requirements_path: Path, name="venv") -> Path: @@ -40,122 +43,165 @@ def create_virtualenv_and_install_requirements(requirements_path: Path, name="ve return venv_dir +# --- Server Management --- # + + +def _run_server(): + """Starts the Letta server in a background thread.""" + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + @pytest.fixture -def empty_mcp_config(tmp_path): +def empty_mcp_config(): path = Path(__file__).parent / "mcp_config.json" path.write_text(json.dumps({})) # writes "{}" return path -@pytest.fixture -def server(empty_mcp_config): - config = LettaConfig.load() - print("CONFIG PATH", config.config_path) +@pytest.fixture() +def server_url(empty_mcp_config): + """Ensures a server is running and returns its base URL.""" + url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - config.save() + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + wait_for_server(url) - old_dir = constants.LETTA_DIR - constants.LETTA_DIR = str(Path(__file__).parent) - - server = SyncServer() - yield server - constants.LETTA_DIR = old_dir + return url -@pytest.fixture -def default_user(server): - user = server.user_manager.get_user_or_default() - yield user +@pytest.fixture() +def client(server_url): + """Creates a REST client for testing.""" + client = Letta(base_url=server_url) + return client -def test_sse_mcp_server(server, default_user): - assert server.mcp_clients == {} +@pytest.fixture() +def agent_state(client): + """Creates an agent and ensures cleanup after tests.""" + agent_state = client.agents.create( + name=f"test_compl_{str(uuid.uuid4())[5:]}", + include_base_tools=True, + memory_blocks=[ + { + "label": "human", + "value": "Name: Matt", + }, + { + "label": "persona", + "value": "Friendly agent", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + yield agent_state + client.agents.delete(agent_state.id) + +def test_sse_mcp_server(client, agent_state): mcp_server_name = "github_composio" server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8" sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url) - server.add_mcp_server_to_config(sse_mcp_config) - - # Check that it's in clients - assert mcp_server_name in server.mcp_clients + client.tools.add_mcp_server(request=sse_mcp_config) # Check that it's in the server mapping - mcp_server_mapping = server.get_mcp_servers() + mcp_server_mapping = client.tools.list_mcp_servers() assert mcp_server_name in mcp_server_mapping - assert mcp_server_mapping[mcp_server_name] == sse_mcp_config # Check tools - tools = server.get_tools_from_mcp_server(mcp_server_name) + tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name) assert len(tools) > 0 - assert isinstance(tools[0], MCPTool) + assert isinstance(tools[0], McpTool) star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None) # Check that one of the tools are executable - tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=star_mcp_tool) - server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user) + letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name) - function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool( - tool_name=star_mcp_tool.name, tool_args={"owner": "letta-ai", "repo": "letta"} + tool_args = {"owner": "letta-ai", "repo": "letta"} + + # Add to agent, have agent invoke tool + client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id) + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", + content=[TextContent(text=f"Use the `{letta_tool.name}` tool with these arguments: {tool_args}.")], + ) + ], ) - assert not is_error - function_response = parse_json(function_response) - assert function_response.get("successful"), function_response - assert function_response.get("data").get("details") == "Action executed successfully", function_response + seq = response.messages + calls = [m for m in seq if isinstance(m, ToolCallMessage)] + assert calls, "Expected a ToolCallMessage" + assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER" + + returns = [m for m in seq if isinstance(m, ToolReturnMessage)] + assert returns, "Expected a ToolReturnMessage" + tr = returns[0] + # status field + assert tr.status == "success", f"Bad status: {tr.status}" + # parse JSON payload + payload = json.loads(tr.tool_return) + assert payload.get("successful", False), f"Tool returned failure payload: {payload}" + assert payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {payload}" -def test_stdio_mcp_server(server, default_user): - assert server.mcp_clients == {} - - # Create venv - create_virtualenv_and_install_requirements(Path(__file__).parent / "weather" / "requirements.txt") +def test_stdio_mcp_server(client, agent_state): + req_file = Path(__file__).parent / "weather" / "requirements.txt" + create_virtualenv_and_install_requirements(req_file, name="venv") mcp_server_name = "weather" command = str(Path(__file__).parent / "weather" / "venv" / "bin" / "python3") args = [str(Path(__file__).parent / "weather" / "weather.py")] - stdio_mcp_config = StdioServerConfig(server_name=mcp_server_name, command=command, args=args) - server.add_mcp_server_to_config(stdio_mcp_config) - # Check that it's in clients - assert mcp_server_name in server.mcp_clients - - # Check that it's in the server mapping - mcp_server_mapping = server.get_mcp_servers() - assert mcp_server_name in mcp_server_mapping - assert mcp_server_mapping[mcp_server_name] == StdioServerConfig( - server_name=mcp_server_name, type=MCPServerType.STDIO, command=command, args=args, env=None + stdio_config = StdioServerConfig( + server_name=mcp_server_name, + command=command, + args=args, ) - # Check that it can return valid tools - tools = server.get_tools_from_mcp_server(mcp_server_name) - assert tools == [ - MCPTool( - name="get_alerts", - description="Get weather alerts for a US state.\n\n Args:\n state: Two-letter US state code (e.g. CA, NY)\n ", - inputSchema={ - "properties": {"state": {"title": "State", "type": "string"}}, - "required": ["state"], - "title": "get_alertsArguments", - "type": "object", - }, - ), - MCPTool( - name="get_forecast", - description="Get weather forecast for a location.\n\n Args:\n latitude: Latitude of the location\n longitude: Longitude of the location\n ", - inputSchema={ - "properties": {"latitude": {"title": "Latitude", "type": "number"}, "longitude": {"title": "Longitude", "type": "number"}}, - "required": ["latitude", "longitude"], - "title": "get_forecastArguments", - "type": "object", - }, - ), - ] - get_alerts_mcp_tool = tools[0] + client.tools.add_mcp_server(request=stdio_config) - tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=get_alerts_mcp_tool) - server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user) + servers = client.tools.list_mcp_servers() + assert mcp_server_name in servers - # Attempt running the tool - function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(tool_name="get_alerts", tool_args={"state": "CA"}) - assert not is_error - assert len(function_response) > 20, function_response # Crude heuristic for an expected result + tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name) + assert tools, "Expected at least one tool from the weather MCP server" + assert any(t.name == "get_alerts" for t in tools), f"Got: {[t.name for t in tools]}" + + get_alerts = next(t for t in tools if t.name == "get_alerts") + + letta_tool = client.tools.add_mcp_tool( + mcp_server_name=mcp_server_name, + mcp_tool_name=get_alerts.name, + ) + + client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id) + + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", + content=[TextContent(text=(f"Use the `{letta_tool.name}` tool with these arguments: " f"{{'state': 'CA'}}."))], + ) + ], + ) + + calls = [m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "get_alerts"] + assert calls, "Expected a get_alerts ToolCallMessage" + + returns = [m for m in response.messages if isinstance(m, ToolReturnMessage) and m.tool_call_id == calls[0].tool_call.tool_call_id] + assert returns, "Expected a ToolReturnMessage for get_alerts" + ret = returns[0] + + assert ret.status == "success", f"Unexpected status: {ret.status}" + # make sure there's at least some payload + assert len(ret.tool_return.strip()) >= 10, f"Expected at least 10 characters in tool_return, got {len(ret.tool_return.strip())}"