* add test_agent_manager.py * created shared conftest * add test_tool_manager.py * add tag tests * add message manager tests * add blocks * add org * add passage tests * add archive manager * add user manager * add identity * add job manager tests * add sandbox manager * add file manager * add group managers * add mcp manager * fix batch tests * update workflows * fix test_managers.py * more tests * comment out old test and add file --------- Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
721 lines
29 KiB
Python
721 lines
29 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import string
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import List
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import pytest
|
|
from _pytest.python_api import approx
|
|
from anthropic.types.beta import BetaMessage
|
|
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
|
|
|
|
# Import shared fixtures and constants from conftest
|
|
from conftest import (
|
|
CREATE_DELAY_SQLITE,
|
|
DEFAULT_EMBEDDING_CONFIG,
|
|
USING_SQLITE,
|
|
)
|
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.exc import IntegrityError, InvalidRequestError
|
|
from sqlalchemy.orm.exc import StaleDataError
|
|
|
|
from letta.config import LettaConfig
|
|
from letta.constants import (
|
|
BASE_MEMORY_TOOLS,
|
|
BASE_SLEEPTIME_TOOLS,
|
|
BASE_TOOLS,
|
|
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
|
|
BASE_VOICE_SLEEPTIME_TOOLS,
|
|
BUILTIN_TOOLS,
|
|
DEFAULT_ORG_ID,
|
|
DEFAULT_ORG_NAME,
|
|
FILES_TOOLS,
|
|
LETTA_TOOL_EXECUTION_DIR,
|
|
LETTA_TOOL_SET,
|
|
LOCAL_ONLY_MULTI_AGENT_TOOLS,
|
|
MCP_TOOL_TAG_NAME_PREFIX,
|
|
MULTI_AGENT_TOOLS,
|
|
)
|
|
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
|
from letta.errors import LettaAgentNotFoundError
|
|
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
|
from letta.functions.mcp_client.types import MCPTool
|
|
from letta.helpers import ToolRulesSolver
|
|
from letta.helpers.datetime_helpers import AsyncTimer
|
|
from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
|
|
from letta.orm import Base, Block
|
|
from letta.orm.block_history import BlockHistory
|
|
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
|
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
|
|
from letta.schemas.agent import CreateAgent, UpdateAgent
|
|
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.enums import (
|
|
ActorType,
|
|
AgentStepStatus,
|
|
FileProcessingStatus,
|
|
JobStatus,
|
|
JobType,
|
|
MessageRole,
|
|
ProviderType,
|
|
SandboxType,
|
|
StepStatus,
|
|
TagMatchMode,
|
|
ToolType,
|
|
VectorDBProvider,
|
|
)
|
|
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
|
from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata
|
|
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert
|
|
from letta.schemas.job import BatchJob, Job, Job as PydanticJob, JobUpdate, LettaRequestConfig
|
|
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
|
from letta.schemas.letta_message_content import TextContent
|
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
|
from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message as PydanticMessage, MessageCreate, MessageUpdate
|
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
|
from letta.schemas.organization import Organization, Organization as PydanticOrganization, OrganizationUpdate
|
|
from letta.schemas.passage import Passage as PydanticPassage
|
|
from letta.schemas.pip_requirement import PipRequirement
|
|
from letta.schemas.run import Run as PydanticRun
|
|
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
|
|
from letta.schemas.source import Source as PydanticSource, SourceUpdate
|
|
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
|
from letta.schemas.tool_rule import InitToolRule
|
|
from letta.schemas.user import User as PydanticUser, UserUpdate
|
|
from letta.server.db import db_registry
|
|
from letta.server.server import SyncServer
|
|
from letta.services.block_manager import BlockManager
|
|
from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async
|
|
from letta.services.step_manager import FeedbackType
|
|
from letta.settings import settings, tool_settings
|
|
from letta.utils import calculate_file_defaults_based_on_context_window
|
|
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
|
|
from tests.utils import random_string
|
|
|
|
# ======================================================================================================================
|
|
# MCPManager Tests
|
|
# ======================================================================================================================
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server(mock_get_client, server, default_user):
    """Create stdio and SSE MCP servers, then list, execute, and attach an MCP tool.

    ``MCPManager.get_mcp_client`` is patched, so every client call goes to an
    ``AsyncMock`` that advertises a single ``get_simple_price`` tool and returns a
    canned JSON payload from ``execute_tool``.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType, StdioServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        # Report a skip rather than a silent pass when this mode is active.
        pytest.skip("MCP servers are read from config; manager create/update path not exercised")

    tool_name = "get_simple_price"

    # Mock client exposing the methods this test relies on.
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(
        return_value=[
            MCPTool(
                name=tool_name,
                inputSchema={
                    "type": "object",
                    "properties": {
                        "ids": {"type": "string"},
                        "vs_currencies": {"type": "string"},
                        "include_market_cap": {"type": "boolean"},
                        "include_24hr_vol": {"type": "boolean"},
                        "include_24hr_change": {"type": "boolean"},
                    },
                    "required": ["ids", "vs_currencies"],
                    "additionalProperties": False,
                },
            )
        ]
    )
    mock_client.execute_tool = AsyncMock(
        return_value=(
            '{"bitcoin": {"usd": 50000, "usd_market_cap": 900000000000, "usd_24h_vol": 30000000000, "usd_24h_change": 2.5}}',
            True,
        )
    )
    mock_get_client.return_value = mock_client

    # Stdio server.
    stdio_config = StdioServerConfig(
        server_name="test_server", type=MCPServerType.STDIO, command="echo 'test'", args=["arg1", "arg2"], env={"ENV1": "value1"}
    )
    stdio_server = await server.mcp_manager.create_or_update_mcp_server(
        MCPServer(server_name="test_server", server_type=MCPServerType.STDIO, stdio_config=stdio_config),
        actor=default_user,
    )
    assert stdio_server.server_name == stdio_config.server_name
    assert stdio_server.server_type == stdio_config.type

    # SSE server. Only the URL is needed on MCPServer; the previous SSEServerConfig was never used.
    sse_server_name = "coingecko"
    sse_server_url = "https://mcp.api.coingecko.com/sse"
    sse_server = await server.mcp_manager.create_or_update_mcp_server(
        MCPServer(server_name=sse_server_name, server_type=MCPServerType.SSE, server_url=sse_server_url),
        actor=default_user,
    )
    assert sse_server.server_name == sse_server_name
    assert sse_server.server_type == MCPServerType.SSE

    servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
    assert len(servers) > 0, "No MCP servers found"

    # Listing and executing go through the mocked client; this checks the calls complete.
    await server.mcp_manager.list_mcp_server_tools(sse_server.server_name, actor=default_user)
    tool_args = {
        "ids": "bitcoin",
        "vs_currencies": "usd",
        "include_market_cap": True,
        "include_24hr_vol": True,
        "include_24hr_change": True,
    }
    await server.mcp_manager.execute_mcp_server_tool(
        sse_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={}
    )

    # Attaching the tool should tag it with the owning server.
    tool = await server.mcp_manager.add_tool_from_mcp_server(sse_server.server_name, tool_name, actor=default_user)
    expected_tag = f"mcp:{sse_server.server_name}"
    assert tool.name == tool_name
    assert expected_tag in tool.tags, f"Expected tag {expected_tag}, got {tool.tags}"
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools(mock_get_client, server, default_user):
    """Test that creating an MCP server automatically syncs and persists its tools.

    The mocked client reports two VALID tools, one INVALID tool, and one WARNING
    tool. Everything except the INVALID tool is expected to be persisted. After
    the server is deleted, none of its tools should remain.
    """
    from letta.functions.mcp_client.types import MCPToolHealth
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        # Report a skip rather than a silent pass when this mode is active.
        pytest.skip("MCP servers are read from config; manager create path not exercised")

    mock_tools = [
        MCPTool(
            name="valid_tool_1",
            description="A valid tool",
            inputSchema={
                "type": "object",
                "properties": {
                    "param1": {"type": "string"},
                },
                "required": ["param1"],
            },
            health=MCPToolHealth(status="VALID", reasons=[]),
        ),
        MCPTool(
            name="valid_tool_2",
            description="Another valid tool",
            inputSchema={
                "type": "object",
                "properties": {
                    "param2": {"type": "number"},
                },
            },
            health=MCPToolHealth(status="VALID", reasons=[]),
        ),
        MCPTool(
            name="invalid_tool",
            description="An invalid tool that should be skipped",
            inputSchema={
                "type": "invalid_type",  # Invalid schema
            },
            health=MCPToolHealth(status="INVALID", reasons=["Invalid schema type"]),
        ),
        MCPTool(
            name="warning_tool",
            description="A tool with warnings but should still be persisted",
            inputSchema={
                "type": "object",
                "properties": {},
            },
            health=MCPToolHealth(status="WARNING", reasons=["No properties defined"]),
        ),
    ]

    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # The random suffix keeps the server name unique across runs.
    server_name = f"test_server_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-with-tools.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    def _belongs_to_server(tool) -> bool:
        """Return True if the tool's MCP metadata names this test's server."""
        metadata = tool.metadata_ or {}
        return MCP_TOOL_TAG_NAME_PREFIX in metadata and metadata[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name

    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
    try:
        assert created_server.server_name == server_name
        assert created_server.server_type == MCPServerType.SSE
        assert created_server.server_url == server_url

        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool", "invalid_tool"]
        )
        persisted_tools = [tool for tool in all_tools if _belongs_to_server(tool)]

        # 2 valid + 1 warning; the invalid tool must be filtered out.
        assert len(persisted_tools) == 3, f"Expected 3 tools, got {len(persisted_tools)}"
        tool_names = {tool.name for tool in persisted_tools}
        assert "valid_tool_1" in tool_names
        assert "valid_tool_2" in tool_names
        assert "warning_tool" in tool_names
        assert "invalid_tool" not in tool_names

        for tool in persisted_tools:
            mcp_metadata = tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]
            assert mcp_metadata["server_name"] == server_name
            assert mcp_metadata["server_id"] == created_server.id
            assert tool.tool_type == ToolType.EXTERNAL_MCP
    finally:
        # Always delete the server, even if an assertion above failed.
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)

    # Deleting the server should cascade to its tools.
    remaining_tools = await server.tool_manager.list_tools_async(actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool"])
    remaining_mcp_tools = [tool for tool in remaining_tools if _belongs_to_server(tool)]
    assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
    """Test that MCP server creation succeeds even when tool sync fails (optimistic approach).

    The mocked client raises from ``connect_to_server``. The server row must still
    be created, and no tools tagged with this server may be persisted.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        # Report a skip rather than a silent pass when this mode is active.
        pytest.skip("MCP servers are read from config; manager create path not exercised")

    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock(side_effect=Exception("Connection failed"))
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    server_name = f"test_server_fail_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-fail.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
    try:
        assert created_server.server_name == server_name
        assert created_server.server_type == MCPServerType.SSE
        assert created_server.server_url == server_url

        # The names are generic placeholders; the mock never advertised any tools.
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user,
            names=["tool1", "tool2", "tool3"],
        )
        persisted_tools = [
            tool
            for tool in all_tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]
        assert len(persisted_tools) == 0, "No tools should be persisted when connection fails"
    finally:
        # Always delete the server, even if an assertion above failed.
        await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_mcp_servers_by_ids(server, default_user):
    """Bulk-fetch MCP servers by ID for full, subset, empty, and partly-invalid ID lists.

    Creates one stdio, one SSE, and one streamable-HTTP server, and deletes them
    again at the end whether or not the assertions pass.
    """
    from letta.schemas.mcp import MCPServer, MCPServerType, StdioServerConfig
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        # Report a skip rather than a silent pass when this mode is active.
        pytest.skip("MCP servers are read from config; manager create/update path not exercised")

    # URL-based servers only need ``server_url``. The streamable-HTTP server was previously
    # described with an SSEServerConfig, which mismatched its declared type.
    servers_to_create = [
        MCPServer(
            server_name="test_server_1",
            server_type=MCPServerType.STDIO,
            stdio_config=StdioServerConfig(
                server_name="test_server_1", type=MCPServerType.STDIO, command="echo 'test1'", args=["arg1"], env={"ENV1": "value1"}
            ),
        ),
        MCPServer(server_name="test_server_2", server_type=MCPServerType.SSE, server_url="https://test2.example.com/sse"),
        MCPServer(server_name="test_server_3", server_type=MCPServerType.STREAMABLE_HTTP, server_url="https://test3.example.com/mcp"),
    ]

    created_servers = []
    try:
        for mcp_server in servers_to_create:
            created = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
            created_servers.append(created)

        # Fetch all created servers.
        server_ids = [s.id for s in created_servers]
        fetched_servers = await server.mcp_manager.get_mcp_servers_by_ids(server_ids, actor=default_user)
        assert len(fetched_servers) == len(created_servers)
        assert {s.id for s in fetched_servers} == {s.id for s in created_servers}

        # Fetch a subset.
        subset_ids = server_ids[:2]
        subset_servers = await server.mcp_manager.get_mcp_servers_by_ids(subset_ids, actor=default_user)
        assert len(subset_servers) == 2
        assert all(s.id in subset_ids for s in subset_servers)

        # An empty ID list yields an empty result.
        empty_result = await server.mcp_manager.get_mcp_servers_by_ids([], actor=default_user)
        assert empty_result == []

        # Unknown IDs are ignored; only existing servers come back.
        mixed_ids = [server_ids[0], "non-existent-id", server_ids[1]]
        mixed_result = await server.mcp_manager.get_mcp_servers_by_ids(mixed_ids, actor=default_user)
        assert len(mixed_result) == 2
        assert all(s.id in server_ids for s in mixed_result)

        # Partial org-scoping check: everything fetched belongs to the actor's organization.
        # A full cross-org test would need a second user/org.
        all_servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
        all_server_ids = [s.id for s in all_servers]
        bulk_fetched = await server.mcp_manager.get_mcp_servers_by_ids(all_server_ids, actor=default_user)
        assert all(s.organization_id == default_user.organization_id for s in bulk_fetched)
    finally:
        # Remove only the servers this test created.
        for created in created_servers:
            await server.mcp_manager.delete_mcp_server_by_id(created.id, actor=default_user)
|
|
|
|
|
|
# Additional MCPManager OAuth session tests
|
|
@pytest.mark.asyncio
async def test_mcp_server_deletion_cascades_oauth_sessions(server, default_organization, default_user):
    """Deleting an MCP server removes the OAuth sessions that share its user and URL.

    Three sessions are opened before any server exists, so they start with no
    server id. A server with the same URL is then created and deleted. Afterwards,
    none of the sessions may be retrievable.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    test_server_url = "https://test.example.com/mcp"
    manager = server.mcp_manager

    async def _open_session(index: int) -> str:
        """Create one orphaned session for the shared URL and return its id."""
        opened = await manager.create_oauth_session(
            MCPOAuthSessionCreate(
                server_url=test_server_url,
                server_name=f"test_mcp_server_{index}",
                user_id=default_user.id,
                organization_id=default_organization.id,
            ),
            actor=default_user,
        )
        return opened.id

    session_ids = [await _open_session(index) for index in range(3)]

    # The random suffix keeps the server name unique.
    suffix = uuid.uuid4().hex[:8]
    doomed_server = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_mcp_server_{suffix}",
            server_type=MCPServerType.SSE,
            server_url=test_server_url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    await manager.delete_mcp_server_by_id(doomed_server.id, actor=default_user)

    for session_id in session_ids:
        leftover = await manager.get_oauth_session_by_id(session_id, actor=default_user)
        assert leftover is None, f"OAuth session {session_id} should be deleted"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_oauth_sessions_with_different_url_persist(server, default_organization, default_user):
    """An OAuth session for another URL survives deletion of an unrelated server.

    One session is opened for ``other.example.com``. A server is then created
    and deleted at ``test.example.com``, and the first session must still be
    retrievable.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    deleted_server_url = "https://test.example.com/mcp"
    surviving_url = "https://other.example.com/mcp"
    manager = server.mcp_manager

    survivor = await manager.create_oauth_session(
        MCPOAuthSessionCreate(
            server_url=surviving_url,
            server_name="standalone_oauth",
            user_id=default_user.id,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    target = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_mcp_server_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url=deleted_server_url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    await manager.delete_mcp_server_by_id(target.id, actor=default_user)

    still_present = await manager.get_oauth_session_by_id(survivor.id, actor=default_user)
    assert still_present is not None, "OAuth session with different URL should persist"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_server_creation_links_orphaned_sessions(server, default_organization, default_user):
    """Creating a server should link any existing orphaned sessions (same user + URL).

    The link is checked indirectly, in two steps. First, the sessions must still
    be retrievable after the server is created. Second, deleting the server must
    remove them.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    url = "https://test-atomic-create.example.com/mcp"
    manager = server.mcp_manager

    # Sessions created before the server exists have no server id yet.
    orphaned_ids: list[str] = []
    for idx in range(3):
        orphan = await manager.create_oauth_session(
            MCPOAuthSessionCreate(
                server_url=url,
                server_name=f"atomic_session_{idx}",
                user_id=default_user.id,
                organization_id=default_organization.id,
            ),
            actor=default_user,
        )
        orphaned_ids.append(orphan.id)

    new_server = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_atomic_server_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url=url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    # Creating the server must not remove the sessions.
    for sid in orphaned_ids:
        assert await manager.get_oauth_session_by_id(sid, actor=default_user) is not None

    # Deleting the server removes the sessions for that URL + user.
    await manager.delete_mcp_server_by_id(new_server.id, actor=default_user)
    for sid in orphaned_ids:
        assert await manager.get_oauth_session_by_id(sid, actor=default_user) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, default_organization, default_user):
    """Deleting a server removes the OAuth sessions for the same user and URL.

    A single session is opened before the server exists. Once a server with that
    URL is created and deleted, the session must be gone.
    """
    from letta.schemas.mcp import MCPOAuthSessionCreate, MCPServer as PydanticMCPServer, MCPServerType

    url = "https://test-atomic-cleanup.example.com/mcp"
    manager = server.mcp_manager

    orphan = await manager.create_oauth_session(
        MCPOAuthSessionCreate(
            server_url=url,
            server_name="orphaned",
            user_id=default_user.id,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )

    cleanup_target = await manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"cleanup_server_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url=url,
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    await manager.delete_mcp_server_by_id(cleanup_target.id, actor=default_user)

    assert await manager.get_oauth_session_by_id(orphan.id, actor=default_user) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_server_resync_tools(server, default_user, default_organization):
    """Test that resyncing MCP server tools correctly handles added, deleted, and updated tools.

    ``tool1`` and ``tool2`` are persisted first. The mocked server then reports
    an updated ``tool1`` (extra ``param1b``) and a new ``tool3``. After resync:
    ``tool2`` is deleted, ``tool1`` carries the new schema, and ``tool3`` exists.
    The server is deleted in ``finally``.
    """
    from unittest.mock import AsyncMock, patch

    from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
    from letta.schemas.mcp import MCPServer as PydanticMCPServer, MCPServerType
    from letta.schemas.tool import ToolCreate

    mcp_server = await server.mcp_manager.create_mcp_server(
        PydanticMCPServer(
            server_name=f"test_resync_{uuid.uuid4().hex[:8]}",
            server_type=MCPServerType.SSE,
            server_url="https://test-resync.example.com/mcp",
            organization_id=default_organization.id,
        ),
        actor=default_user,
    )
    mcp_server_id = mcp_server.id

    try:
        # Persist the initial tools directly (via the async tool manager API),
        # simulating tools added before the resync.
        tool1_create = ToolCreate.from_mcp(
            mcp_server_name=mcp_server.server_name,
            mcp_tool=MCPTool(
                name="tool1",
                description="Tool 1",
                inputSchema={"type": "object", "properties": {"param1": {"type": "string"}}},
            ),
        )
        tool1 = await server.tool_manager.create_or_update_mcp_tool_async(
            tool_create=tool1_create,
            mcp_server_name=mcp_server.server_name,
            mcp_server_id=mcp_server_id,
            actor=default_user,
        )

        tool2_create = ToolCreate.from_mcp(
            mcp_server_name=mcp_server.server_name,
            mcp_tool=MCPTool(
                name="tool2",
                description="Tool 2 to be deleted",
                inputSchema={"type": "object", "properties": {"param2": {"type": "number"}}},
            ),
        )
        tool2 = await server.tool_manager.create_or_update_mcp_tool_async(
            tool_create=tool2_create,
            mcp_server_name=mcp_server.server_name,
            mcp_server_id=mcp_server_id,
            actor=default_user,
        )

        # What the server now reports: tool1 is updated, tool2 is gone, tool3 is new.
        updated_tools = [
            MCPTool(
                name="tool1",
                description="Tool 1 Updated",
                inputSchema={"type": "object", "properties": {"param1": {"type": "string"}, "param1b": {"type": "boolean"}}},
                health=MCPToolHealth(status="VALID", reasons=[]),
            ),
            MCPTool(
                name="tool3",
                description="Tool 3 New",
                inputSchema={"type": "object", "properties": {"param3": {"type": "array"}}},
                health=MCPToolHealth(status="VALID", reasons=[]),
            ),
        ]

        with patch.object(server.mcp_manager, "list_mcp_server_tools", new_callable=AsyncMock) as mock_list_tools:
            mock_list_tools.return_value = updated_tools

            result = await server.mcp_manager.resync_mcp_server_tools(
                mcp_server_name=mcp_server.server_name,
                actor=default_user,
            )

            assert len(result.deleted) == 1
            assert "tool2" in result.deleted

            assert len(result.updated) == 1
            assert "tool1" in result.updated

            assert len(result.added) == 1
            assert "tool3" in result.added

            # tool2 must really be gone. pytest.raises fails the test if no error is raised.
            # The old `assert False` inside `except Exception` caught its own AssertionError.
            with pytest.raises(Exception):
                await server.tool_manager.get_tool_by_id_async(tool_id=tool2.id, actor=default_user)

            # tool1 picked up the new schema.
            updated_tool1 = await server.tool_manager.get_tool_by_id_async(tool_id=tool1.id, actor=default_user)
            assert "param1b" in updated_tool1.json_schema["parameters"]["properties"]

            # tool3 was added.
            tools = await server.tool_manager.list_tools_async(actor=default_user, names=["tool3"])
            assert len(tools) == 1
            assert tools[0].name == "tool3"

    finally:
        await server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=default_user)
|