Files
letta-server/tests/managers/test_mcp_manager.py
jnjpng 00ba2d09f3 refactor: migrate mcp_servers and mcp_oauth to encrypted-only columns (#6751)
* refactor: migrate mcp_servers and mcp_oauth to encrypted-only columns

Complete migration to encrypted-only storage for sensitive fields:

- Remove dual-write to plaintext columns (token, custom_headers,
  authorization_code, access_token, refresh_token, client_secret)
- Read only from _enc columns, not from plaintext fallback
- Remove helper methods (get_token_secret, set_token_secret, etc.)
- Remove Secret.from_db() and Secret.to_dict() methods
- Update tests to verify encrypted-only behavior

After this change, plaintext columns can be set to NULL manually
since they are no longer read from or written to.

* fix test

* rename

* update

* union

* fix test
2025-12-17 17:31:02 -08:00

1143 lines
47 KiB
Python

import json
import logging
import os
import random
import re
import string
import time
import uuid
from datetime import datetime, timedelta, timezone
from typing import List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from _pytest.python_api import approx
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
# Import shared fixtures and constants from conftest
from conftest import (
CREATE_DELAY_SQLITE,
DEFAULT_EMBEDDING_CONFIG,
USING_SQLITE,
)
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import StaleDataError
from letta.config import LettaConfig
from letta.constants import (
BASE_MEMORY_TOOLS,
BASE_SLEEPTIME_TOOLS,
BASE_TOOLS,
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
BASE_VOICE_SLEEPTIME_TOOLS,
BUILTIN_TOOLS,
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
FILES_TOOLS,
LETTA_TOOL_EXECUTION_DIR,
LETTA_TOOL_SET,
LOCAL_ONLY_MULTI_AGENT_TOOLS,
MCP_TOOL_TAG_NAME_PREFIX,
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaAgentNotFoundError
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer
from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import (
ActorType,
AgentStepStatus,
FileProcessingStatus,
JobStatus,
JobType,
MessageRole,
ProviderType,
SandboxType,
StepStatus,
TagMatchMode,
ToolType,
VectorDBProvider,
)
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert
from letta.schemas.job import BatchJob, Job, Job as PydanticJob, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization, Organization as PydanticOrganization, OrganizationUpdate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.run import Run as PydanticRun
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.source import Source as PydanticSource, SourceUpdate
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
from letta.schemas.tool_rule import InitToolRule
from letta.schemas.user import User as PydanticUser, UserUpdate
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async
from letta.services.step_manager import FeedbackType
from letta.settings import settings, tool_settings
from letta.utils import calculate_file_defaults_based_on_context_window
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
from tests.utils import random_string
# ======================================================================================================================
# MCPManager Tests
# ======================================================================================================================
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server(mock_get_client, server, default_user):
    """Create a stdio and an SSE MCP server, then list, execute and register a tool.

    ``MCPManager.get_mcp_client`` is patched, so every client interaction is served
    by an ``AsyncMock`` that advertises a single ``get_simple_price`` tool and returns
    a canned execution result. The test is a no-op when MCP servers are configured
    from file (``tool_settings.mcp_read_from_config``).
    """
    from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    price_tool_schema = {
        "type": "object",
        "properties": {
            "ids": {"type": "string"},
            "vs_currencies": {"type": "string"},
            "include_market_cap": {"type": "boolean"},
            "include_24hr_vol": {"type": "boolean"},
            "include_24hr_change": {"type": "boolean"},
        },
        "required": ["ids", "vs_currencies"],
        "additionalProperties": False,
    }
    canned_execution = (
        '{"bitcoin": {"usd": 50000, "usd_market_cap": 900000000000, "usd_24h_vol": 30000000000, "usd_24h_change": 2.5}}',
        True,
    )

    # Fake MCP client handed out by the patched get_mcp_client.
    fake_client = AsyncMock()
    fake_client.connect_to_server = AsyncMock()
    fake_client.list_tools = AsyncMock(return_value=[MCPTool(name="get_simple_price", inputSchema=price_tool_schema)])
    fake_client.execute_tool = AsyncMock(return_value=canned_execution)
    mock_get_client.return_value = fake_client

    manager = server.mcp_manager

    # --- stdio server ---
    stdio_config = StdioServerConfig(
        server_name="test_server", type=MCPServerType.STDIO, command="echo 'test'", args=["arg1", "arg2"], env={"ENV1": "value1"}
    )
    stdio_server = await manager.create_or_update_mcp_server(
        MCPServer(server_name="test_server", server_type=MCPServerType.STDIO, stdio_config=stdio_config),
        actor=default_user,
    )
    print(stdio_server)
    assert stdio_server.server_name == stdio_config.server_name
    assert stdio_server.server_type == stdio_config.type

    # --- SSE server ---
    sse_name = "coingecko"
    sse_url = "https://mcp.api.coingecko.com/sse"
    sse_config = SSEServerConfig(server_name=sse_name, server_url=sse_url)  # constructed but not otherwise used
    sse_server = await manager.create_or_update_mcp_server(
        MCPServer(server_name=sse_name, server_type=MCPServerType.SSE, server_url=sse_url),
        actor=default_user,
    )
    print(sse_server)
    assert sse_server.server_name == sse_name
    assert sse_server.server_type == MCPServerType.SSE

    # The org should now have at least one MCP server.
    listed_servers = await manager.list_mcp_servers(actor=default_user)
    print(listed_servers)
    assert len(listed_servers) > 0, "No MCP servers found"

    sse_tools = await manager.list_mcp_server_tools(sse_server.server_name, actor=default_user)
    print(sse_tools)

    # Execute the advertised tool through the manager (served by the fake client).
    tool_name = "get_simple_price"
    execution = await manager.execute_mcp_server_tool(
        sse_server.server_name,
        tool_name=tool_name,
        tool_args={
            "ids": "bitcoin",
            "vs_currencies": "usd",
            "include_market_cap": True,
            "include_24hr_vol": True,
            "include_24hr_change": True,
        },
        actor=default_user,
        environment_variables={},
    )
    print(execution)

    # Register the tool locally and check it carries the server tag.
    added_tool = await manager.add_tool_from_mcp_server(sse_server.server_name, tool_name, actor=default_user)
    print(added_tool)
    assert added_tool.name == tool_name
    expected_tag = f"mcp:{sse_server.server_name}"
    assert expected_tag in added_tool.tags, f"Expected tag {expected_tag}, got {added_tool.tags}"
    print("TAGS", added_tool.tags)
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools(mock_get_client, server, default_user):
    """Test that creating an MCP server automatically syncs and persists its tools.

    The mocked client advertises four tools: two VALID, one WARNING and one INVALID.
    After ``create_mcp_server_with_tools`` the three non-INVALID tools must be
    persisted with MCP metadata pointing at the new server. Deleting the server must
    remove them again. The test is a no-op when MCP servers are read from config.
    """
    from letta.functions.mcp_client.types import MCPToolHealth
    from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock tools with different health statuses
    mock_tools = [
        MCPTool(
            name="valid_tool_1",
            description="A valid tool",
            inputSchema={
                "type": "object",
                "properties": {
                    "param1": {"type": "string"},
                },
                "required": ["param1"],
            },
            health=MCPToolHealth(status="VALID", reasons=[]),
        ),
        MCPTool(
            name="valid_tool_2",
            description="Another valid tool",
            inputSchema={
                "type": "object",
                "properties": {
                    "param2": {"type": "number"},
                },
            },
            health=MCPToolHealth(status="VALID", reasons=[]),
        ),
        MCPTool(
            name="invalid_tool",
            description="An invalid tool that should be skipped",
            inputSchema={
                "type": "invalid_type",  # Invalid schema
            },
            health=MCPToolHealth(status="INVALID", reasons=["Invalid schema type"]),
        ),
        MCPTool(
            name="warning_tool",
            description="A tool with warnings but should still be persisted",
            inputSchema={
                "type": "object",
                "properties": {},
            },
            health=MCPToolHealth(status="WARNING", reasons=["No properties defined"]),
        ),
    ]

    # Create mock client
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server config
    server_name = f"test_server_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-with-tools.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    def _tools_for_server(tools):
        """Keep only tools whose MCP metadata names this test's server."""
        return [
            tool
            for tool in tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]

    # Create server with tools using the new method
    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
    try:
        # Verify server was created
        assert created_server.server_name == server_name
        assert created_server.server_type == MCPServerType.SSE
        assert created_server.server_url == server_url

        # Verify tools were persisted (all except the invalid one)
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool", "invalid_tool"]
        )
        persisted_tools = _tools_for_server(all_tools)

        # Should have 3 tools (2 valid + 1 warning, but not the invalid one)
        assert len(persisted_tools) == 3, f"Expected 3 tools, got {len(persisted_tools)}"

        tool_names = {tool.name for tool in persisted_tools}
        assert "valid_tool_1" in tool_names
        assert "valid_tool_2" in tool_names
        assert "warning_tool" in tool_names
        assert "invalid_tool" not in tool_names  # Invalid tool should be filtered out

        # Verify each tool has correct metadata
        for tool in persisted_tools:
            assert tool.metadata_ is not None
            assert MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == server_name
            assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == created_server.id
            assert tool.tool_type == ToolType.EXTERNAL_MCP
    finally:
        # Always delete the server, even if an assertion above failed, so no state
        # leaks into later tests. The cascade check below depends on this deletion.
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)

    # Verify tools were also deleted (cascade) by trying to get them again
    remaining_tools = await server.tool_manager.list_tools_async(actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool"])
    remaining_mcp_tools = _tools_for_server(remaining_tools)
    assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_complex_schema_normalization(mock_get_client, server, default_user):
    """Test that complex MCP schemas with nested objects are normalized and accepted.

    The mocked client reports three tools whose schemas use ``$defs`` and ``$ref``
    and leave out ``additionalProperties``; all three are flagged INVALID.
    ``add_tool_from_mcp_server`` is expected to normalize each schema and persist
    the tool. The test is a no-op when MCP servers are read from config.
    """
    from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock tools with complex schemas that would normally be INVALID
    # These schemas have: nested $defs, $ref references, missing additionalProperties
    mock_tools = [
        # 1. Nested object with $ref (like create_person)
        MCPTool(
            name="create_person",
            description="Create a person with nested address",
            inputSchema={
                "$defs": {
                    "Address": {
                        "type": "object",
                        "properties": {
                            "street": {"type": "string"},
                            "city": {"type": "string"},
                            "zip_code": {"type": "string"},
                        },
                        "required": ["street", "city", "zip_code"],
                    },
                    "Person": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "age": {"type": "integer"},
                            "address": {"$ref": "#/$defs/Address"},
                        },
                        "required": ["name", "age"],
                    },
                },
                "type": "object",
                "properties": {"person": {"$ref": "#/$defs/Person"}},
                "required": ["person"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.person: Missing 'type'"],
            ),
        ),
        # 2. List of objects (like manage_tasks)
        MCPTool(
            name="manage_tasks",
            description="Manage multiple tasks",
            inputSchema={
                "$defs": {
                    "TaskItem": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "priority": {"type": "integer", "default": 1},
                            "completed": {"type": "boolean", "default": False},
                            "tags": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["title"],
                    }
                },
                "type": "object",
                "properties": {
                    "tasks": {
                        "type": "array",
                        "items": {"$ref": "#/$defs/TaskItem"},
                    }
                },
                "required": ["tasks"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.tasks.items: Missing 'type'"],
            ),
        ),
        # 3. Complex filter object with optional fields
        MCPTool(
            name="search_with_filters",
            description="Search with complex filters",
            inputSchema={
                "$defs": {
                    "SearchFilter": {
                        "type": "object",
                        "properties": {
                            "keywords": {"type": "array", "items": {"type": "string"}},
                            "min_score": {"type": "number"},
                            "categories": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["keywords"],
                    }
                },
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "filters": {"$ref": "#/$defs/SearchFilter"},
                },
                "required": ["query", "filters"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.filters: Missing 'type'"],
            ),
        ),
    ]

    # Create mock client
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server
    server_name = f"test_complex_schema_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-complex.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    # Bound before the try so the cleanup in `finally` never hits an unbound name
    # (which would mask the real error) if server creation itself fails.
    created_server = None
    try:
        # Create server (this will auto-sync tools)
        created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
        assert created_server.server_name == server_name

        # Now attempt to add each tool - they should be normalized from INVALID to acceptable
        # The normalization happens in add_tool_from_mcp_server

        # Test 1: create_person should normalize successfully
        person_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "create_person", actor=default_user)
        assert person_tool is not None
        assert person_tool.name == "create_person"
        # Verify the schema has additionalProperties set
        assert person_tool.json_schema["parameters"]["additionalProperties"] is False
        # Verify nested $defs have additionalProperties
        if "$defs" in person_tool.json_schema["parameters"]:
            for def_name, def_schema in person_tool.json_schema["parameters"]["$defs"].items():
                if def_schema.get("type") == "object":
                    assert "additionalProperties" in def_schema, f"$defs.{def_name} missing additionalProperties after normalization"

        # Test 2: manage_tasks should normalize successfully
        tasks_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "manage_tasks", actor=default_user)
        assert tasks_tool is not None
        assert tasks_tool.name == "manage_tasks"
        # Verify array items have explicit type
        tasks_prop = tasks_tool.json_schema["parameters"]["properties"]["tasks"]
        assert "items" in tasks_prop
        assert "type" in tasks_prop["items"], "Array items should have explicit type after normalization"

        # Test 3: search_with_filters should normalize successfully
        search_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "search_with_filters", actor=default_user)
        assert search_tool is not None
        assert search_tool.name == "search_with_filters"

        # Verify all tools were persisted
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user, names=["create_person", "manage_tasks", "search_with_filters"]
        )
        # Filter to tools from our MCP server
        mcp_tools = [
            tool
            for tool in all_tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]
        # All 3 complex schema tools should have been normalized and persisted
        assert len(mcp_tools) == 3, f"Expected 3 normalized tools, got {len(mcp_tools)}"
        # Verify they all have the correct MCP metadata
        for tool in mcp_tools:
            assert tool.tool_type == ToolType.EXTERNAL_MCP
            assert f"mcp:{server_name}" in tool.tags
    finally:
        # Clean up only if the server was actually created
        if created_server is not None:
            await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
    """Test that MCP server creation succeeds even when tool sync fails (optimistic approach).

    The mocked client raises on ``connect_to_server``. The server record must still be
    created, and no tools may be attached to it. The test is a no-op when MCP servers
    are read from config.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock client that fails to connect
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock(side_effect=Exception("Connection failed"))
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server config
    server_name = f"test_server_fail_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-fail.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    # Create server with tools - should succeed despite connection failure
    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
    try:
        # Verify server was created successfully
        assert created_server.server_name == server_name
        assert created_server.server_type == MCPServerType.SSE
        assert created_server.server_url == server_url

        # Verify no tools were persisted (due to connection failure)
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user,
            names=["tool1", "tool2", "tool3"],  # Generic names since we don't know what tools would have been listed
        )
        # Filter to see if any belong to our server (there shouldn't be any)
        persisted_tools = [
            tool
            for tool in all_tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]
        assert len(persisted_tools) == 0, "No tools should be persisted when connection fails"
    finally:
        # Clean up even if an assertion failed
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
async def test_get_mcp_servers_by_ids(server, default_user):
    """Bulk lookup of MCP servers by id returns exactly the existing, org-scoped matches.

    Covers fetching all ids, a subset, an empty list, and a list that mixes in an
    unknown id. The servers created here are deleted afterwards so they do not leak
    into other tests. The test is a no-op when MCP servers are read from config.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create multiple MCP servers for testing
    servers_data = [
        {
            "name": "test_server_1",
            "config": StdioServerConfig(
                server_name="test_server_1", type=MCPServerType.STDIO, command="echo 'test1'", args=["arg1"], env={"ENV1": "value1"}
            ),
            "type": MCPServerType.STDIO,
        },
        {
            "name": "test_server_2",
            "config": SSEServerConfig(server_name="test_server_2", server_url="https://test2.example.com/sse"),
            "type": MCPServerType.SSE,
        },
        {
            "name": "test_server_3",
            # Only .server_url is read from this config; the server type comes from "type".
            "config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/mcp"),
            "type": MCPServerType.STREAMABLE_HTTP,
        },
    ]

    created_servers = []
    try:
        for server_data in servers_data:
            if server_data["type"] == MCPServerType.STDIO:
                mcp_server = MCPServer(server_name=server_data["name"], server_type=server_data["type"], stdio_config=server_data["config"])
            else:
                mcp_server = MCPServer(
                    server_name=server_data["name"], server_type=server_data["type"], server_url=server_data["config"].server_url
                )
            created = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
            created_servers.append(created)

        # Test fetching multiple servers by IDs
        server_ids = [s.id for s in created_servers]
        fetched_servers = await server.mcp_manager.get_mcp_servers_by_ids(server_ids, actor=default_user)
        assert len(fetched_servers) == len(created_servers)
        assert {s.id for s in fetched_servers} == {s.id for s in created_servers}

        # Test fetching subset of servers
        subset_ids = server_ids[:2]
        subset_servers = await server.mcp_manager.get_mcp_servers_by_ids(subset_ids, actor=default_user)
        assert len(subset_servers) == 2
        assert all(s.id in subset_ids for s in subset_servers)

        # Test fetching with empty list
        empty_result = await server.mcp_manager.get_mcp_servers_by_ids([], actor=default_user)
        assert empty_result == []

        # Test fetching with non-existent ID mixed with valid IDs
        mixed_ids = [server_ids[0], "non-existent-id", server_ids[1]]
        mixed_result = await server.mcp_manager.get_mcp_servers_by_ids(mixed_ids, actor=default_user)
        # Should only return the existing servers
        assert len(mixed_result) == 2
        assert all(s.id in server_ids for s in mixed_result)

        # A bulk fetch of every listed server should only return servers in the
        # actor's organization. (A cross-org check would need a second org/user.)
        all_servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
        all_server_ids = [s.id for s in all_servers]
        bulk_fetched = await server.mcp_manager.get_mcp_servers_by_ids(all_server_ids, actor=default_user)
        assert all(s.organization_id == default_user.organization_id for s in bulk_fetched)
    finally:
        # Remove the servers this test created so they don't leak into other tests
        for created in created_servers:
            await server.mcp_manager.delete_mcp_server_by_id(created.id, actor=default_user)
# Additional MCPManager OAuth session tests
@pytest.mark.asyncio
async def test_mcp_server_deletion_cascades_oauth_sessions(server, default_organization, default_user):
    """Deleting an MCP server deletes associated OAuth sessions (same user + URL).

    Three OAuth sessions are created first, so none carries a server id. A server with
    the same URL is then created and deleted, and every one of those sessions must be
    gone afterwards.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    test_server_url = "https://test.example.com/mcp"
    manager = server.mcp_manager

    # Orphaned sessions: same user and URL, no server yet.
    session_ids: list[str] = []
    for idx in range(3):
        oauth_request = MCPOAuthSessionCreate(
            server_url=test_server_url,
            server_name=f"test_mcp_server_{idx}",
            user_id=default_user.id,
            organization_id=default_organization.id,
        )
        new_session = await manager.create_oauth_session(oauth_request, actor=default_user)
        session_ids.append(new_session.id)

    # A uniquely named server pointing at the same URL.
    server_spec = PydanticMCPServer(
        server_name=f"test_mcp_server_{uuid.uuid4().hex[:8]}",
        server_type=MCPServerType.SSE,
        server_url=test_server_url,
        organization_id=default_organization.id,
    )
    created_server = await manager.create_mcp_server(server_spec, actor=default_user)

    await manager.delete_mcp_server_by_id(created_server.id, actor=default_user)

    for sid in session_ids:
        lookup = await manager.get_oauth_session_by_id(sid, actor=default_user)
        assert lookup is None, f"OAuth session {sid} should be deleted"
@pytest.mark.asyncio
async def test_oauth_sessions_with_different_url_persist(server, default_organization, default_user):
    """Sessions with different URL should not be deleted when deleting the server for another URL.

    An OAuth session is created for ``other_url``. A server at ``server_url`` is then
    created and deleted, and the ``other_url`` session must still be readable.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    server_url = "https://test.example.com/mcp"
    other_url = "https://other.example.com/mcp"
    manager = server.mcp_manager

    unrelated_session = await manager.create_oauth_session(
        MCPOAuthSessionCreate(
            server_url=other_url,
            server_name="standalone_oauth",
            user_id=default_user.id,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    doomed_server = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_mcp_server_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url=server_url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    await manager.delete_mcp_server_by_id(doomed_server.id, actor=default_user)

    survivor = await manager.get_oauth_session_by_id(unrelated_session.id, actor=default_user)
    assert survivor is not None, "OAuth session with different URL should persist"
@pytest.mark.asyncio
async def test_mcp_server_creation_links_orphaned_sessions(server, default_organization, default_user):
    """Creating a server should link any existing orphaned sessions (same user + URL).

    The link itself is not inspected. The test checks that the sessions survive server
    creation and are removed once that server is deleted.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    server_url = "https://test-atomic-create.example.com/mcp"
    manager = server.mcp_manager

    async def _make_orphan(index: int) -> str:
        """Create one server-less OAuth session for this user/URL and return its id."""
        orphan = await manager.create_oauth_session(
            MCPOAuthSessionCreate(
                server_url=server_url,
                server_name=f"atomic_session_{index}",
                user_id=default_user.id,
                organization_id=default_organization.id,
            ),
            actor=default_user,
        )
        return orphan.id

    # Created one after another, in index order.
    orphaned_ids: list[str] = [await _make_orphan(i) for i in range(3)]

    created_server = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_atomic_server_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url=server_url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    # Server creation must not drop the pre-existing sessions.
    for sid in orphaned_ids:
        assert await manager.get_oauth_session_by_id(sid, actor=default_user) is not None

    # Indirect verification: deleting the server removes sessions for that URL+user.
    await manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
    for sid in orphaned_ids:
        assert await manager.get_oauth_session_by_id(sid, actor=default_user) is None
@pytest.mark.asyncio
async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, default_organization, default_user):
    """Deleting a server removes both linked and orphaned sessions for same user+URL.

    A session is created before the server exists. The server is then created and
    deleted, after which that session must no longer be readable.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    server_url = "https://test-atomic-cleanup.example.com/mcp"
    manager = server.mcp_manager

    orphan_request = MCPOAuthSessionCreate(
        server_url=server_url,
        server_name="orphaned",
        user_id=default_user.id,
        organization_id=default_organization.id,
    )
    orphaned = await manager.create_oauth_session(orphan_request, actor=default_user)

    server_spec = PydanticMCPServer(
        server_name=f"cleanup_server_{uuid.uuid4().hex[:8]}",
        server_type=MCPServerType.SSE,
        server_url=server_url,
        organization_id=default_organization.id,
    )
    cleanup_target = await manager.create_mcp_server(server_spec, actor=default_user)

    await manager.delete_mcp_server_by_id(cleanup_target.id, actor=default_user)

    # No session for this URL+user may survive the server's deletion.
    leftover = await manager.get_oauth_session_by_id(orphaned.id, actor=default_user)
    assert leftover is None
@pytest.mark.asyncio
async def test_mcp_server_resync_tools(server, default_user, default_organization):
    """Test that resyncing MCP server tools correctly handles added, deleted, and updated tools.

    ``tool1`` and ``tool2`` are persisted for a fresh server. ``list_mcp_server_tools``
    is then patched to return an updated ``tool1`` and a new ``tool3``. The resync must
    report tool2 as deleted, tool1 as updated and tool3 as added, and the stored tools
    must reflect all three changes.
    """
    from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
    from letta.schemas.mcp import MCPServer as PydanticMCPServer, MCPServerType
    from letta.schemas.tool import ToolCreate

    # Create MCP server
    mcp_server = await server.mcp_manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_resync_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url="https://test-resync.example.com/mcp",
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    mcp_server_id = mcp_server.id

    try:
        # Create initial persisted tools (simulating previously added tools)
        # via the async create_or_update_mcp_tool_async path.
        tool1_create = ToolCreate.from_mcp(
            mcp_server_name=mcp_server.server_name,
            mcp_tool=MCPTool(
                name="tool1",
                description="Tool 1",
                inputSchema={"type": "object", "properties": {"param1": {"type": "string"}}},
            ),
        )
        tool1 = await server.tool_manager.create_or_update_mcp_tool_async(
            tool_create=tool1_create,
            mcp_server_name=mcp_server.server_name,
            mcp_server_id=mcp_server_id,
            actor=default_user,
        )

        tool2_create = ToolCreate.from_mcp(
            mcp_server_name=mcp_server.server_name,
            mcp_tool=MCPTool(
                name="tool2",
                description="Tool 2 to be deleted",
                inputSchema={"type": "object", "properties": {"param2": {"type": "number"}}},
            ),
        )
        tool2 = await server.tool_manager.create_or_update_mcp_tool_async(
            tool_create=tool2_create,
            mcp_server_name=mcp_server.server_name,
            mcp_server_id=mcp_server_id,
            actor=default_user,
        )

        # Mock the list_mcp_server_tools to return updated tools from server
        # tool1 is updated, tool2 is deleted, tool3 is added
        updated_tools = [
            MCPTool(
                name="tool1",
                description="Tool 1 Updated",
                inputSchema={"type": "object", "properties": {"param1": {"type": "string"}, "param1b": {"type": "boolean"}}},
                health=MCPToolHealth(status="VALID", reasons=[]),
            ),
            MCPTool(
                name="tool3",
                description="Tool 3 New",
                inputSchema={"type": "object", "properties": {"param3": {"type": "array"}}},
                health=MCPToolHealth(status="VALID", reasons=[]),
            ),
        ]

        with patch.object(server.mcp_manager, "list_mcp_server_tools", new_callable=AsyncMock) as mock_list_tools:
            mock_list_tools.return_value = updated_tools
            # Run resync
            result = await server.mcp_manager.resync_mcp_server_tools(
                mcp_server_name=mcp_server.server_name,
                actor=default_user,
            )

        # Verify the resync result
        assert len(result.deleted) == 1
        assert "tool2" in result.deleted
        assert len(result.updated) == 1
        assert "tool1" in result.updated
        assert len(result.added) == 1
        assert "tool3" in result.added

        # Verify tool2 was actually deleted. A lookup of a deleted tool may raise or
        # return nothing; both count as deleted. The assertion must stay outside the
        # try, otherwise `except Exception` would swallow its AssertionError.
        try:
            lingering_tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool2.id, actor=default_user)
        except Exception:
            lingering_tool = None  # Expected - tool should be deleted
        assert lingering_tool is None, "Tool2 should have been deleted"

        # Verify tool1 was updated with new schema
        updated_tool1 = await server.tool_manager.get_tool_by_id_async(tool_id=tool1.id, actor=default_user)
        assert "param1b" in updated_tool1.json_schema["parameters"]["properties"]

        # Verify tool3 was added
        tools = await server.tool_manager.list_tools_async(actor=default_user, names=["tool3"])
        assert len(tools) == 1
        assert tools[0].name == "tool3"
    finally:
        # Clean up
        await server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=default_user)
# ======================================================================================================================
# MCPManager Tests - Encryption
# ======================================================================================================================
@pytest.fixture
def encryption_key():
    """Make sure an encryption key is configured while a test runs.

    When ``settings.encryption_key`` is empty/unset, a placeholder test key is
    installed for the duration of the test. On teardown the value that was
    present before the fixture ran is written back, whatever it was.

    Yields:
        The encryption key that is active during the test.
    """
    previous_key = settings.encryption_key
    if not previous_key:
        # Leave an externally configured key alone; only fill the gap.
        settings.encryption_key = "test-encryption-key-32-bytes!!"
    yield settings.encryption_key
    # Teardown: put back exactly what was there before (even a falsy value).
    settings.encryption_key = previous_key
@pytest.mark.asyncio
async def test_mcp_server_token_encryption_on_create(server, default_user, encryption_key):
    """Creating an MCP server must persist its token only in encrypted form.

    Checks the returned schema object: the plaintext ``token`` is unset, and
    ``token_enc`` is a ``Secret`` that decrypts to the original value. Also
    checks the raw ORM row: the ``token_enc`` column is populated and differs
    from the plaintext token.
    """
    from letta.functions.mcp_client.types import MCPServerType
    from letta.orm.mcp_server import MCPServer as MCPServerModel
    from letta.schemas.mcp import MCPServer
    from letta.schemas.secret import Secret

    plaintext_token = "sk-test-secret-token-12345"
    created_server = await server.mcp_manager.create_mcp_server(
        MCPServer(
            server_name="test-encrypted-server",
            server_type=MCPServerType.STREAMABLE_HTTP,
            server_url="https://api.example.com/mcp",
            token=plaintext_token,
        ),
        actor=default_user,
    )

    try:
        assert created_server is not None
        assert created_server.server_name == "test-encrypted-server"

        # Encrypted-only storage: nothing is dual-written to the plaintext field.
        assert created_server.token is None

        token_secret = created_server.token_enc
        assert token_secret is not None
        assert isinstance(token_secret, Secret)
        assert token_secret.get_plaintext() == plaintext_token

        # Go straight to the ORM row to inspect what is actually stored.
        async with db_registry.async_session() as session:
            row = await MCPServerModel.read_async(
                db_session=session,
                identifier=created_server.id,
                actor=default_user,
            )
            stored_value = row.token_enc
            assert stored_value is not None
            assert stored_value != plaintext_token
            # The stored (encrypted) value is expected to be longer than the plaintext.
            assert len(stored_value) > len(plaintext_token)
    finally:
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
async def test_mcp_server_token_decryption_on_read(server, default_user, encryption_key):
    """Reading an MCP server back must yield a correctly decrypted token.

    The plaintext ``token`` field stays unset. The original token is reachable
    both through ``token_enc`` and through ``get_token_secret()``.
    """
    from letta.functions.mcp_client.types import MCPServerType
    from letta.schemas.mcp import MCPServer
    from letta.schemas.secret import Secret

    expected_token = "sk-test-decrypt-token-67890"
    new_server = MCPServer(
        server_name="test-decrypt-server",
        server_type=MCPServerType.STREAMABLE_HTTP,
        server_url="https://api.example.com/mcp",
        token=expected_token,
    )
    created = await server.mcp_manager.create_mcp_server(new_server, actor=default_user)
    server_id = created.id

    try:
        fetched = await server.mcp_manager.get_mcp_server_by_id_async(server_id, actor=default_user)

        # No dual-write: the plaintext field must not be populated.
        assert fetched.token is None

        assert fetched.token_enc is not None
        assert fetched.token_enc.get_plaintext() == expected_token

        secret = fetched.get_token_secret()
        assert isinstance(secret, Secret)
        assert secret.get_plaintext() == expected_token
    finally:
        await server.mcp_manager.delete_mcp_server_by_id(server_id, actor=default_user)
@pytest.mark.asyncio
async def test_mcp_server_custom_headers_encryption(server, default_user, encryption_key):
    """Custom headers must be stored only as an encrypted JSON string.

    Verifies the returned schema object:
    - the plaintext ``custom_headers`` field is unset;
    - ``custom_headers_enc`` is a ``Secret``;
    - ``get_custom_headers_secret()`` yields JSON equal to the original headers;
    - ``get_custom_headers_dict()`` returns the original dict.

    Also verifies the raw ORM row: ``custom_headers_enc`` is populated and
    decrypts to JSON equal to the original headers.

    Relies on the module-level ``import json``; the previous function-local
    re-imports were redundant.
    """
    from letta.functions.mcp_client.types import MCPServerType
    from letta.orm.mcp_server import MCPServer as MCPServerModel
    from letta.schemas.mcp import MCPServer
    from letta.schemas.secret import Secret

    # Create MCP server with custom headers
    custom_headers = {"Authorization": "Bearer token123", "X-API-Key": "secret-key-456"}
    mcp_server = MCPServer(
        server_name="test-headers-server",
        server_type=MCPServerType.STREAMABLE_HTTP,
        server_url="https://api.example.com/mcp",
        custom_headers=custom_headers,
    )
    created_server = await server.mcp_manager.create_mcp_server(mcp_server, actor=default_user)

    try:
        # Verify plaintext custom_headers field is NOT set (no dual-write)
        assert created_server.custom_headers is None

        # custom_headers_enc is a Secret wrapping the JSON-serialized headers
        assert created_server.custom_headers_enc is not None
        assert isinstance(created_server.custom_headers_enc, Secret)

        # The getter returns a Secret whose plaintext is a JSON string
        headers_secret = created_server.get_custom_headers_secret()
        assert isinstance(headers_secret, Secret)
        json_str = headers_secret.get_plaintext()
        assert json_str is not None
        assert json.loads(json_str) == custom_headers

        # The convenience accessor returns the decoded dict directly
        assert created_server.get_custom_headers_dict() == custom_headers

        # Read from DB to verify what is actually persisted
        async with db_registry.async_session() as session:
            server_orm = await MCPServerModel.read_async(
                db_session=session,
                identifier=created_server.id,
                actor=default_user,
            )
            assert server_orm.custom_headers_enc is not None
            # Decrypt and verify it's valid JSON matching the original headers
            decrypted_json = Secret.from_encrypted(server_orm.custom_headers_enc).get_plaintext()
            assert json.loads(decrypted_json) == custom_headers
    finally:
        # Clean up
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@pytest.mark.asyncio
async def test_oauth_session_tokens_encryption(server, default_user, encryption_key):
    """OAuth session secrets must be persisted encrypted and read back intact.

    The test updates a session with an access token, refresh token, client
    secret and authorization code. It then checks the returned session: each
    value is readable directly, and each ``*_enc`` field is a ``Secret``. It
    also checks the ORM row: every ``*_enc`` column is populated and decrypts
    to the value that was written.
    """
    from letta.orm.mcp_oauth import MCPOAuth as MCPOAuthModel
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPOAuthSessionUpdate
    from letta.schemas.secret import Secret

    # Field name -> value written in the update; drives every check below in order.
    expected = {
        "access_token": "access-token-abc123",
        "refresh_token": "refresh-token-xyz789",
        "client_secret": "client-secret-def456",
        "authorization_code": "auth-code-ghi012",
    }

    created_session = await server.mcp_manager.create_oauth_session(
        MCPOAuthSessionCreate(
            server_url="https://oauth.example.com",
            server_name="test-oauth-server",
            organization_id=default_user.organization_id,
            user_id=default_user.id,
        ),
        actor=default_user,
    )
    session_id = created_session.id

    try:
        updated_session = await server.mcp_manager.update_oauth_session(
            session_id,
            MCPOAuthSessionUpdate(**expected),
            actor=default_user,
        )

        # Each value is readable directly on the returned session...
        for field, value in expected.items():
            assert getattr(updated_session, field) == value
        # ...and each matching *_enc field holds a Secret.
        for field in expected:
            assert isinstance(getattr(updated_session, f"{field}_enc"), Secret)

        # Inspect the persisted row: every encrypted column is set and round-trips.
        async with db_registry.async_session() as session:
            oauth_orm = await MCPOAuthModel.read_async(
                db_session=session,
                identifier=session_id,
                actor=default_user,
            )
            for field in expected:
                assert getattr(oauth_orm, f"{field}_enc") is not None
            for field, value in expected.items():
                assert Secret.from_encrypted(getattr(oauth_orm, f"{field}_enc")).get_plaintext() == value
    finally:
        await server.mcp_manager.delete_oauth_session(session_id, actor=default_user)