feat: exception handling middleware for sandbox_configs + identities + tools (#5143)

This commit is contained in:
Sarah Wooders
2025-10-06 23:52:05 -07:00
committed by Caren Thomas
parent 307c85ca9a
commit 324933edd3
6 changed files with 307 additions and 503 deletions

View File

@@ -109,6 +109,18 @@ class LettaMCPError(LettaError):
"""Base error for MCP-related issues."""
class LettaInvalidMCPSchemaError(LettaMCPError):
"""Error raised when an invalid MCP schema is provided."""
def __init__(self, server_name: str, mcp_tool_name: str, reasons: List[str]):
details = {"server_name": server_name, "mcp_tool_name": mcp_tool_name, "reasons": reasons}
super().__init__(
message=f"MCP tool {mcp_tool_name} has an invalid schema and cannot be attached - reasons: {reasons}",
code=ErrorCode.INVALID_ARGUMENT,
details=details,
)
class LettaMCPConnectionError(LettaMCPError):
"""Error raised when unable to connect to MCP server."""

View File

@@ -27,6 +27,7 @@ from letta.errors import (
BedrockPermissionError,
LettaAgentNotFoundError,
LettaInvalidArgumentError,
LettaInvalidMCPSchemaError,
LettaMCPConnectionError,
LettaMCPTimeoutError,
LettaToolCreateError,
@@ -264,6 +265,7 @@ def create_application() -> "FastAPI":
# 408 Timeout errors
app.add_exception_handler(LettaMCPTimeoutError, _error_handler_408)
app.add_exception_handler(LettaInvalidMCPSchemaError, _error_handler_400)
# 409 Conflict errors
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from fastapi import APIRouter, Body, Depends, Header, Query
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.schemas.agent import AgentState
@@ -39,26 +39,19 @@ async def list_identities(
"""
Get a list of all identities in the database
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
identities = await server.identity_manager.list_identities_async(
name=name,
project_id=project_id,
identifier_key=identifier_key,
identity_type=identity_type,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
actor=actor,
)
except HTTPException:
raise
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
identities = await server.identity_manager.list_identities_async(
name=name,
project_id=project_id,
identifier_key=identifier_key,
identity_type=identity_type,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
actor=actor,
)
return identities
@@ -70,15 +63,11 @@ async def count_identities(
"""
Get count of all identities for a user
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.size_async(actor=actor)
except NoResultFound:
return 0
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
@@ -87,11 +76,8 @@ async def retrieve_identity(
server: "SyncServer" = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
@router.post("/", tags=["identities"], response_model=Identity, operation_id="create_identity")
@@ -103,21 +89,8 @@ async def create_identity(
None, alias="X-Project", description="The project slug to associate with the identity (cloud only)."
), # Only handled by next js middleware
):
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.create_identity_async(identity=identity, actor=actor)
except HTTPException:
raise
except UniqueConstraintViolationError:
if identity.project_id:
raise HTTPException(
status_code=409,
detail=f"An identity with identifier key {identity.identifier_key} already exists for project {identity.project_id}",
)
else:
raise HTTPException(status_code=409, detail=f"An identity with identifier key {identity.identifier_key} already exists")
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.create_identity_async(identity=identity, actor=actor)
@router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity")
@@ -129,15 +102,8 @@ async def upsert_identity(
None, alias="X-Project", description="The project slug to associate with the identity (cloud only)."
), # Only handled by next js middleware
):
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor)
@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity")
@@ -147,15 +113,8 @@ async def modify_identity(
server: "SyncServer" = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor)
@router.put("/{identity_id}/properties", tags=["identities"], operation_id="upsert_identity_properties")
@@ -165,15 +124,8 @@ async def upsert_identity_properties(
server: "SyncServer" = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
):
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor)
@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity")
@@ -185,15 +137,8 @@ async def delete_identity(
"""
Delete an identity by its identifier key
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor)
@router.get("/{identity_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_identity")
@@ -218,20 +163,15 @@ async def list_agents_for_identity(
"""
Get all agents associated with the specified identity.
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.list_agents_for_identity_async(
identity_id=identity_id,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
actor=actor,
)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found")
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.list_agents_for_identity_async(
identity_id=identity_id,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
actor=actor,
)
@router.get("/{identity_id}/blocks", response_model=List[Block], operation_id="list_blocks_for_identity")
@@ -256,17 +196,12 @@ async def list_blocks_for_identity(
"""
Get all blocks associated with the specified identity.
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.list_blocks_for_identity_async(
identity_id=identity_id,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
actor=actor,
)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found")
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
return await server.identity_manager.list_blocks_for_identity_async(
identity_id=identity_id,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
actor=actor,
)

View File

@@ -2,8 +2,9 @@ import os
import shutil
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from letta.errors import LettaInvalidArgumentError
from letta.log import get_logger
from letta.schemas.enums import SandboxType
from letta.schemas.environment_variables import (
@@ -68,9 +69,8 @@ async def create_custom_local_sandbox_config(
"""
# Ensure the incoming config is of type LOCAL
if local_sandbox_config.type != SandboxType.LOCAL:
raise HTTPException(
status_code=400,
detail=f"Provided config must be of type '{SandboxType.LOCAL.value}'.",
raise LettaInvalidArgumentError(
f"Provided config must be of type '{SandboxType.LOCAL.value}'.", argument_name="local_sandbox_config.type"
)
# Retrieve the user (actor)
@@ -138,25 +138,16 @@ async def force_recreate_local_sandbox_venv(
# Check if venv exists, and delete if necessary
if os.path.isdir(venv_path):
try:
shutil.rmtree(venv_path)
logger.info(f"Deleted existing virtual environment at: {venv_path}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to delete existing venv: {e}")
shutil.rmtree(venv_path)
logger.info(f"Deleted existing virtual environment at: {venv_path}")
# Recreate the virtual environment
try:
create_venv_for_local_sandbox(sandbox_dir_path=sandbox_dir, venv_path=str(venv_path), env=os.environ.copy(), force_recreate=True)
logger.info(f"Successfully recreated virtual environment at: {venv_path}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to recreate venv: {e}")
create_venv_for_local_sandbox(sandbox_dir_path=sandbox_dir, venv_path=str(venv_path), env=os.environ.copy(), force_recreate=True)
logger.info(f"Successfully recreated virtual environment at: {venv_path}")
# Install pip requirements
try:
install_pip_requirements_for_sandbox(local_configs=local_configs, env=os.environ.copy())
logger.info(f"Successfully installed pip requirements for venv at: {venv_path}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to install pip requirements: {e}")
install_pip_requirements_for_sandbox(local_configs=local_configs, env=os.environ.copy())
logger.info(f"Successfully installed pip requirements for venv at: {venv_path}")
return sbx_config

View File

@@ -7,7 +7,14 @@ from httpx import ConnectError, HTTPStatusError
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from letta.errors import LettaToolCreateError, LettaToolNameConflictError
from letta.errors import (
LettaInvalidArgumentError,
LettaInvalidMCPSchemaError,
LettaMCPConnectionError,
LettaMCPTimeoutError,
LettaToolCreateError,
LettaToolNameConflictError,
)
from letta.functions.functions import derive_openai_json_schema
from letta.functions.mcp_client.exceptions import MCPTimeoutError
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
@@ -70,78 +77,74 @@ async def count_tools(
"""
Get a count of all tools available to agents belonging to the org of the user.
"""
try:
# Helper function to parse tool types - supports both repeated params and comma-separated values
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_types_input is None:
return None
# Flatten any comma-separated values and validate against ToolType enum
flattened_types = []
for item in tool_types_input:
# Split by comma in case user provided comma-separated values
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
flattened_types.extend(types_in_item)
# Helper function to parse tool types - supports both repeated params and comma-separated values
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_types_input is None:
return None
# Validate each type against the ToolType enum
valid_types = []
valid_values = [tt.value for tt in ToolType]
# Flatten any comma-separated values and validate against ToolType enum
flattened_types = []
for item in tool_types_input:
# Split by comma in case user provided comma-separated values
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
flattened_types.extend(types_in_item)
for tool_type in flattened_types:
if tool_type not in valid_values:
raise HTTPException(
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
)
valid_types.append(tool_type)
# Validate each type against the ToolType enum
valid_types = []
valid_values = [tt.value for tt in ToolType]
return valid_types if valid_types else None
for tool_type in flattened_types:
if tool_type not in valid_values:
raise HTTPException(status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}")
valid_types.append(tool_type)
# Parse and validate tool types (same logic as list_tools)
tool_types_str = parse_tool_types(tool_types)
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
return valid_types if valid_types else None
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Parse and validate tool types (same logic as list_tools)
tool_types_str = parse_tool_types(tool_types)
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
# Combine single name with names list for unified processing (same logic as list_tools)
combined_names = []
if name is not None:
combined_names.append(name)
if names is not None:
combined_names.extend(names)
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Use None if no names specified, otherwise use the combined list
final_names = combined_names if combined_names else None
# Combine single name with names list for unified processing (same logic as list_tools)
combined_names = []
if name is not None:
combined_names.append(name)
if names is not None:
combined_names.extend(names)
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_ids_input is None:
return None
# Use None if no names specified, otherwise use the combined list
final_names = combined_names if combined_names else None
# Flatten any comma-separated values
flattened_ids = []
for item in tool_ids_input:
# Split by comma in case user provided comma-separated values
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
flattened_ids.extend(ids_in_item)
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_ids_input is None:
return None
return flattened_ids if flattened_ids else None
# Flatten any comma-separated values
flattened_ids = []
for item in tool_ids_input:
# Split by comma in case user provided comma-separated values
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
flattened_ids.extend(ids_in_item)
# Parse tool IDs (same logic as list_tools)
final_tool_ids = parse_tool_ids(tool_ids)
return flattened_ids if flattened_ids else None
# Get the count of tools using unified query
return await server.tool_manager.count_tools_async(
actor=actor,
tool_types=tool_types_str,
exclude_tool_types=exclude_tool_types_str,
names=final_names,
tool_ids=final_tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
exclude_letta_tools=exclude_letta_tools,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Parse tool IDs (same logic as list_tools)
final_tool_ids = parse_tool_ids(tool_ids)
# Get the count of tools using unified query
return await server.tool_manager.count_tools_async(
actor=actor,
tool_types=tool_types_str,
exclude_tool_types=exclude_tool_types_str,
names=final_names,
tool_ids=final_tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
exclude_letta_tools=exclude_letta_tools,
)
@router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool")
@@ -191,81 +194,77 @@ async def list_tools(
"""
Get a list of all tools available to agents.
"""
try:
# Helper function to parse tool types - supports both repeated params and comma-separated values
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_types_input is None:
return None
# Flatten any comma-separated values and validate against ToolType enum
flattened_types = []
for item in tool_types_input:
# Split by comma in case user provided comma-separated values
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
flattened_types.extend(types_in_item)
# Helper function to parse tool types - supports both repeated params and comma-separated values
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_types_input is None:
return None
# Validate each type against the ToolType enum
valid_types = []
valid_values = [tt.value for tt in ToolType]
# Flatten any comma-separated values and validate against ToolType enum
flattened_types = []
for item in tool_types_input:
# Split by comma in case user provided comma-separated values
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
flattened_types.extend(types_in_item)
for tool_type in flattened_types:
if tool_type not in valid_values:
raise HTTPException(
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
)
valid_types.append(tool_type)
# Validate each type against the ToolType enum
valid_types = []
valid_values = [tt.value for tt in ToolType]
return valid_types if valid_types else None
for tool_type in flattened_types:
if tool_type not in valid_values:
raise HTTPException(status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}")
valid_types.append(tool_type)
# Parse and validate tool types
tool_types_str = parse_tool_types(tool_types)
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
return valid_types if valid_types else None
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Parse and validate tool types
tool_types_str = parse_tool_types(tool_types)
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
# Combine single name with names list for unified processing
combined_names = []
if name is not None:
combined_names.append(name)
if names is not None:
combined_names.extend(names)
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Use None if no names specified, otherwise use the combined list
final_names = combined_names if combined_names else None
# Combine single name with names list for unified processing
combined_names = []
if name is not None:
combined_names.append(name)
if names is not None:
combined_names.extend(names)
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_ids_input is None:
return None
# Use None if no names specified, otherwise use the combined list
final_names = combined_names if combined_names else None
# Flatten any comma-separated values
flattened_ids = []
for item in tool_ids_input:
# Split by comma in case user provided comma-separated values
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
flattened_ids.extend(ids_in_item)
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
if tool_ids_input is None:
return None
return flattened_ids if flattened_ids else None
# Flatten any comma-separated values
flattened_ids = []
for item in tool_ids_input:
# Split by comma in case user provided comma-separated values
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
flattened_ids.extend(ids_in_item)
# Parse tool IDs
final_tool_ids = parse_tool_ids(tool_ids)
return flattened_ids if flattened_ids else None
# Get the list of tools using unified query
return await server.tool_manager.list_tools_async(
actor=actor,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
tool_types=tool_types_str,
exclude_tool_types=exclude_tool_types_str,
names=final_names,
tool_ids=final_tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Parse tool IDs
final_tool_ids = parse_tool_ids(tool_ids)
# Get the list of tools using unified query
return await server.tool_manager.list_tools_async(
actor=actor,
before=before,
after=after,
limit=limit,
ascending=(order == "asc"),
tool_types=tool_types_str,
exclude_tool_types=exclude_tool_types_str,
names=final_names,
tool_ids=final_tool_ids,
search=search,
return_only_letta_tools=return_only_letta_tools,
)
@router.post("/", response_model=Tool, operation_id="create_tool")
@@ -379,38 +378,18 @@ async def list_mcp_tools_by_server(
"""
Get a list of all tools for a specific MCP server
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_tools = await server.mcp_manager.list_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor)
return mcp_tools
except Exception as e:
if isinstance(e, ConnectError) or isinstance(e, ConnectionError):
raise HTTPException(
status_code=404,
detail={
"code": "MCPListToolsError",
"message": str(e),
"mcp_server_name": mcp_server_name,
},
)
if isinstance(e, HTTPStatusError):
raise HTTPException(
status_code=401,
detail={
"code": "MCPListToolsError",
"message": str(e),
"mcp_server_name": mcp_server_name,
},
)
except (ConnectError, ConnectionError) as e:
raise LettaMCPConnectionError(str(e), server_name=mcp_server_name)
except HTTPStatusError as e:
# HTTPStatusError from the MCP server likely means auth issue
if e.response.status_code == 401:
raise LettaMCPConnectionError(f"Authentication failed: {e}", server_name=mcp_server_name)
else:
raise HTTPException(
status_code=500,
detail={
"code": "MCPListToolsError",
"message": str(e),
"mcp_server_name": mcp_server_name,
},
)
raise LettaMCPConnectionError(f"HTTP error from MCP server: {e}", server_name=mcp_server_name)
@router.post("/mcp/servers/{mcp_server_name}/resync", operation_id="resync_mcp_server_tools")
@@ -430,29 +409,8 @@ async def resync_mcp_server_tools(
Returns a summary of changes made.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
try:
result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id)
return result
except ValueError as e:
raise HTTPException(
status_code=404,
detail={
"code": "MCPServerNotFoundError",
"message": str(e),
"mcp_server_name": mcp_server_name,
},
)
except Exception as e:
logger.error(f"Unexpected error refreshing MCP server tools: {e}")
raise HTTPException(
status_code=404,
detail={
"code": "MCPRefreshError",
"message": f"Failed to refresh MCP server tools: {str(e)}",
"mcp_server_name": mcp_server_name,
},
)
result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id)
return result
@router.post("/mcp/servers/{mcp_server_name}/{mcp_tool_name}", response_model=Tool, operation_id="add_mcp_tool")
@@ -470,25 +428,8 @@ async def add_mcp_tool(
if tool_settings.mcp_read_from_config:
try:
available_tools = await server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name)
except ValueError as e:
# ValueError means that the MCP server name doesn't exist
raise HTTPException(
status_code=400, # Bad Request
detail={
"code": "MCPServerNotFoundError",
"message": str(e),
"mcp_server_name": mcp_server_name,
},
)
except MCPTimeoutError as e:
raise HTTPException(
status_code=408, # Timeout
detail={
"code": "MCPTimeoutError",
"message": str(e),
"mcp_server_name": mcp_server_name,
},
)
raise LettaMCPTimeoutError(str(e), server_name=mcp_server_name)
# See if the tool is in the available list
mcp_tool = None
@@ -497,27 +438,18 @@ async def add_mcp_tool(
mcp_tool = tool
break
if not mcp_tool:
raise HTTPException(
status_code=400, # Bad Request
detail={
"code": "MCPToolNotFoundError",
"message": f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}",
"mcp_tool_name": mcp_tool_name,
},
raise LettaInvalidArgumentError(
f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}",
argument_name="mcp_tool_name",
)
# Check tool health - reject only INVALID tools
if mcp_tool.health:
if mcp_tool.health.status == "INVALID":
raise HTTPException(
status_code=400,
detail={
"code": "MCPToolSchemaInvalid",
"message": f"Tool {mcp_tool_name} has an invalid schema and cannot be attached",
"mcp_tool_name": mcp_tool_name,
"health_status": mcp_tool.health.status,
"reasons": mcp_tool.health.reasons,
},
raise LettaInvalidMCPSchemaError(
server_name=mcp_server_name,
mcp_tool_name=mcp_tool_name,
reasons=mcp_tool.health.reasons,
)
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
@@ -544,40 +476,27 @@ async def add_mcp_server_to_config(
"""
Add a new MCP server to the Letta MCP server config
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
if tool_settings.mcp_read_from_config:
# write to config file
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
else:
# log to DB
# Check if stdio servers are disabled
if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio:
raise HTTPException(
status_code=400,
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
)
if tool_settings.mcp_read_from_config:
# write to config file
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
else:
# log to DB
# Check if stdio servers are disabled
if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio:
raise HTTPException(
status_code=400,
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
)
# Create MCP server and optimistically sync tools
# The mcp_manager will handle encryption of sensitive fields
await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
# Create MCP server and optimistically sync tools
# The mcp_manager will handle encryption of sensitive fields
await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
# TODO: don't do this in the future (just return MCPServer)
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
return [server.to_config() for server in all_servers]
except UniqueConstraintViolationError:
# If server name already exists, throw 409 conflict error
raise HTTPException(
status_code=409,
detail={
"code": "MCPServerNameAlreadyExistsError",
"message": f"MCP server with name '{request.server_name}' already exists",
"server_name": request.server_name,
},
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
# TODO: don't do this in the future (just return MCPServer)
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
return [server.to_config() for server in all_servers]
@router.patch(
@@ -594,21 +513,15 @@ async def update_mcp_server(
"""
Update an existing MCP server configuration
"""
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
if tool_settings.mcp_read_from_config:
raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.")
else:
updated_server = await server.mcp_manager.update_mcp_server_by_name(
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
)
return updated_server.to_config()
except HTTPException:
# Re-raise HTTP exceptions (like 404)
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
if tool_settings.mcp_read_from_config:
raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.")
else:
updated_server = await server.mcp_manager.update_mcp_server_by_name(
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
)
return updated_server.to_config()
@router.delete(
@@ -660,32 +573,9 @@ async def test_mcp_server(
return {"status": "success", "tools": tools}
except ConnectionError as e:
raise HTTPException(
status_code=400,
detail={
"code": "MCPServerConnectionError",
"message": str(e),
"server_name": request.server_name,
},
)
raise LettaMCPConnectionError(str(e), server_name=request.server_name)
except MCPTimeoutError as e:
raise HTTPException(
status_code=408,
detail={
"code": "MCPTimeoutError",
"message": f"MCP server connection timed out: {str(e)}",
"server_name": request.server_name,
},
)
except Exception as e:
raise HTTPException(
status_code=500,
detail={
"code": "MCPServerTestError",
"message": f"Failed to test MCP server: {str(e)}",
"server_name": request.server_name,
},
)
raise LettaMCPTimeoutError(f"MCP server connection timed out: {str(e)}", server_name=request.server_name)
finally:
if client:
try:
@@ -802,18 +692,14 @@ async def generate_json_schema(
Generate a JSON schema from the given source code defining a function or class.
Supports both Python and TypeScript source code.
"""
try:
if request.source_type == "typescript":
from letta.functions.typescript_parser import derive_typescript_json_schema
if request.source_type == "typescript":
from letta.functions.typescript_parser import derive_typescript_json_schema
schema = derive_typescript_json_schema(source_code=request.code)
else:
# Default to Python for backwards compatibility
schema = derive_openai_json_schema(source_code=request.code)
return schema
except Exception as e:
raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
schema = derive_typescript_json_schema(source_code=request.code)
else:
# Default to Python for backwards compatibility
schema = derive_openai_json_schema(source_code=request.code)
return schema
# TODO: @jnjpng move this and other models above to appropriate file for schemas
@@ -840,14 +726,9 @@ async def execute_mcp_tool(
# Get the MCP server by name
mcp_server = await server.mcp_manager.get_mcp_server(mcp_server_name, actor)
if not mcp_server:
raise HTTPException(
status_code=404,
detail={
"code": "MCPServerNotFound",
"message": f"MCP server '{mcp_server_name}' not found",
"server_name": mcp_server_name,
},
)
from letta.orm.errors import NoResultFound
raise NoResultFound(f"MCP server '{mcp_server_name}' not found")
# Create client and connect
server_config = mcp_server.to_config()
@@ -862,19 +743,6 @@ async def execute_mcp_tool(
"result": result,
"success": success,
}
except HTTPException:
raise
except Exception as e:
logger.warning(f"Error executing MCP tool: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"code": "MCPToolExecutionError",
"message": f"Failed to execute MCP tool: {str(e)}",
"server_name": mcp_server_name,
"tool_name": tool_name,
},
)
finally:
if client:
try:
@@ -944,70 +812,64 @@ async def generate_tool_from_prompt(
"""
Generate a tool from the given user prompt.
"""
response_data = None
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620")
formatted_prompt = (
f"Generate a python function named {request.tool_name} using the instructions below "
+ (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n")
+ (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n")
+ f"Instructions: {request.prompt}"
)
llm_client = LLMClient.create(
provider_type=llm_config.model_endpoint_type,
actor=actor,
)
assert llm_client is not None
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620")
formatted_prompt = (
f"Generate a python function named {request.tool_name} using the instructions below "
+ (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n")
+ (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n")
+ f"Instructions: {request.prompt}"
)
llm_client = LLMClient.create(
provider_type=llm_config.model_endpoint_type,
actor=actor,
)
assert llm_client is not None
assistant_message_ack = "Understood, I will respond with generated python source code and sample arguments that can be used to test the functionality once I receive the user prompt. I'm ready."
assistant_message_ack = "Understood, I will respond with generated python source code and sample arguments that can be used to test the functionality once I receive the user prompt. I'm ready."
input_messages = [
Message(role=MessageRole.system, content=[TextContent(text=get_system_text("memgpt_generate_tool"))]),
Message(role=MessageRole.assistant, content=[TextContent(text=assistant_message_ack)]),
Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]),
]
input_messages = [
Message(role=MessageRole.system, content=[TextContent(text=get_system_text("memgpt_generate_tool"))]),
Message(role=MessageRole.assistant, content=[TextContent(text=assistant_message_ack)]),
Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]),
]
tool = {
"name": "generate_tool",
"description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.",
"parameters": {
"type": "object",
"properties": {
"raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."},
"sample_args_json": {
"type": "string",
"description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.",
},
"pip_requirements_json": {
"type": "string",
"description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.",
},
tool = {
"name": "generate_tool",
"description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.",
"parameters": {
"type": "object",
"properties": {
"raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."},
"sample_args_json": {
"type": "string",
"description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.",
},
"pip_requirements_json": {
"type": "string",
"description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.",
},
"required": ["raw_source_code", "sample_args_json", "pip_requirements_json"],
},
}
request_data = llm_client.build_request_data(
AgentType.letta_v1_agent,
input_messages,
llm_config,
tools=[tool],
)
response_data = await llm_client.request_async(request_data, llm_config)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]
return GenerateToolOutput(
tool=Tool(
name=request.tool_name,
source_type="python",
source_code=output["raw_source_code"],
pip_requirements=pip_requirements,
),
sample_args=json.loads(output["sample_args_json"]),
response=response.choices[0].message.content,
)
except Exception as e:
logger.error(f"Failed to generate tool: {str(e)}. Raw response: {response_data}")
raise HTTPException(status_code=500, detail=f"Failed to generate tool: {str(e)}")
"required": ["raw_source_code", "sample_args_json", "pip_requirements_json"],
},
}
request_data = llm_client.build_request_data(
AgentType.letta_v1_agent,
input_messages,
llm_config,
tools=[tool],
)
response_data = await llm_client.request_async(request_data, llm_config)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]
return GenerateToolOutput(
tool=Tool(
name=request.tool_name,
source_type="python",
source_code=output["raw_source_code"],
pip_requirements=pip_requirements,
),
sample_args=json.loads(output["sample_args_json"]),
response=response.choices[0].message.content,
)

View File

@@ -163,12 +163,14 @@ async def test_add_mcp_tool_rejects_invalid_schemas():
# Should raise HTTPException for invalid schema
headers = HeaderParams(actor_id="test_user")
with pytest.raises(HTTPException) as exc_info:
from letta.errors import LettaInvalidMCPSchemaError
with pytest.raises(LettaInvalidMCPSchemaError) as exc_info:
await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, headers=headers)
assert exc_info.value.status_code == 400
assert "invalid schema" in exc_info.value.detail["message"].lower()
assert exc_info.value.detail["health_status"] == SchemaHealth.INVALID.value
assert "invalid schema" in exc_info.value.message.lower()
assert exc_info.value.details["mcp_tool_name"] == "test_tool"
assert exc_info.value.details["reasons"] == ["Missing 'type' at root level"]
def test_mcp_schema_healing_for_optional_fields():