diff --git a/letta/errors.py b/letta/errors.py index 1d753d35..082fe24a 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -109,6 +109,18 @@ class LettaMCPError(LettaError): """Base error for MCP-related issues.""" +class LettaInvalidMCPSchemaError(LettaMCPError): + """Error raised when an invalid MCP schema is provided.""" + + def __init__(self, server_name: str, mcp_tool_name: str, reasons: List[str]): + details = {"server_name": server_name, "mcp_tool_name": mcp_tool_name, "reasons": reasons} + super().__init__( + message=f"MCP tool {mcp_tool_name} has an invalid schema and cannot be attached - reasons: {reasons}", + code=ErrorCode.INVALID_ARGUMENT, + details=details, + ) + + class LettaMCPConnectionError(LettaMCPError): """Error raised when unable to connect to MCP server.""" diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 0db9ac85..34e0f711 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -27,6 +27,7 @@ from letta.errors import ( BedrockPermissionError, LettaAgentNotFoundError, LettaInvalidArgumentError, + LettaInvalidMCPSchemaError, LettaMCPConnectionError, LettaMCPTimeoutError, LettaToolCreateError, @@ -264,6 +265,7 @@ def create_application() -> "FastAPI": # 408 Timeout errors app.add_exception_handler(LettaMCPTimeoutError, _error_handler_408) + app.add_exception_handler(LettaInvalidMCPSchemaError, _error_handler_400) # 409 Conflict errors app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409) diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index 2a883304..76a28fb3 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Literal, Optional -from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from fastapi import APIRouter, Body, Depends, Header, Query from letta.orm.errors import NoResultFound, UniqueConstraintViolationError from letta.schemas.agent import AgentState @@ -39,26 +39,19 @@ async def list_identities( """ Get a list of all identities in the database """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - identities = await server.identity_manager.list_identities_async( - name=name, - project_id=project_id, - identifier_key=identifier_key, - identity_type=identity_type, - before=before, - after=after, - limit=limit, - ascending=(order == "asc"), - actor=actor, - ) - except HTTPException: - raise - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + identities = await server.identity_manager.list_identities_async( + name=name, + project_id=project_id, + identifier_key=identifier_key, + identity_type=identity_type, + before=before, + after=after, + limit=limit, + ascending=(order == "asc"), + actor=actor, + ) return identities @@ -70,15 +63,11 @@ async def count_identities( """ Get count of all identities for a user """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) return await server.identity_manager.size_async(actor=actor) except NoResultFound: return 0 - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") @router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity") @@ -87,11 +76,8 @@ async def retrieve_identity( server: "SyncServer" = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ): - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor) - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor) @router.post("/", tags=["identities"], response_model=Identity, operation_id="create_identity") @@ -103,21 +89,8 @@ async def create_identity( None, alias="X-Project", description="The project slug to associate with the identity (cloud only)." ), # Only handled by next js middleware ): - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.create_identity_async(identity=identity, actor=actor) - except HTTPException: - raise - except UniqueConstraintViolationError: - if identity.project_id: - raise HTTPException( - status_code=409, - detail=f"An identity with identifier key {identity.identifier_key} already exists for project {identity.project_id}", - ) - else: - raise HTTPException(status_code=409, detail=f"An identity with identifier key {identity.identifier_key} already exists") - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.create_identity_async(identity=identity, actor=actor) @router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity") @@ -129,15 +102,8 @@ async def upsert_identity( None, alias="X-Project", description="The project slug to associate with the identity (cloud only)." ), # Only handled by next js middleware ): - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor) - except HTTPException: - raise - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor) @router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity") @@ -147,15 +113,8 @@ async def modify_identity( server: "SyncServer" = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ): - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor) - except HTTPException: - raise - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor) @router.put("/{identity_id}/properties", tags=["identities"], operation_id="upsert_identity_properties") @@ -165,15 +124,8 @@ async def upsert_identity_properties( server: "SyncServer" = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ): - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor) - except HTTPException: - raise - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor) @router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity") @@ -185,15 +137,8 @@ async def delete_identity( """ Delete an identity by its identifier key """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor) - except HTTPException: - raise - except NoResultFound as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor) @router.get("/{identity_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_identity") @@ -218,20 +163,15 @@ async def list_agents_for_identity( """ Get all agents associated with the specified identity. """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.list_agents_for_identity_async( - identity_id=identity_id, - before=before, - after=after, - limit=limit, - ascending=(order == "asc"), - actor=actor, - ) - except NoResultFound as e: - raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found") - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.list_agents_for_identity_async( + identity_id=identity_id, + before=before, + after=after, + limit=limit, + ascending=(order == "asc"), + actor=actor, + ) @router.get("/{identity_id}/blocks", response_model=List[Block], operation_id="list_blocks_for_identity") @@ -256,17 +196,12 @@ async def list_blocks_for_identity( """ Get all blocks associated with the specified identity. """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - return await server.identity_manager.list_blocks_for_identity_async( - identity_id=identity_id, - before=before, - after=after, - limit=limit, - ascending=(order == "asc"), - actor=actor, - ) - except NoResultFound as e: - raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found") - except Exception as e: - raise HTTPException(status_code=500, detail=f"{e}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + return await server.identity_manager.list_blocks_for_identity_async( + identity_id=identity_id, + before=before, + after=after, + limit=limit, + ascending=(order == "asc"), + actor=actor, + ) diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index 86b72369..94365ec3 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -2,8 +2,9 @@ import os import shutil from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query +from letta.errors import LettaInvalidArgumentError from letta.log import get_logger from letta.schemas.enums import SandboxType from letta.schemas.environment_variables import ( @@ -68,9 +69,8 @@ async def create_custom_local_sandbox_config( """ # Ensure the incoming config is of type LOCAL if local_sandbox_config.type != SandboxType.LOCAL: - raise HTTPException( - status_code=400, - detail=f"Provided config must be of type '{SandboxType.LOCAL.value}'.", + raise LettaInvalidArgumentError( + f"Provided config must be of type '{SandboxType.LOCAL.value}'.", argument_name="local_sandbox_config.type" ) # Retrieve the user (actor) @@ -138,25 +138,16 @@ async def force_recreate_local_sandbox_venv( # Check if venv exists, and delete if necessary if os.path.isdir(venv_path): - try: - shutil.rmtree(venv_path) - logger.info(f"Deleted existing virtual environment at: {venv_path}") - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to delete existing venv: {e}") + shutil.rmtree(venv_path) + logger.info(f"Deleted existing virtual environment at: {venv_path}") # Recreate the virtual environment - try: - create_venv_for_local_sandbox(sandbox_dir_path=sandbox_dir, venv_path=str(venv_path), env=os.environ.copy(), force_recreate=True) - logger.info(f"Successfully recreated virtual environment at: {venv_path}") - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to recreate venv: {e}") + create_venv_for_local_sandbox(sandbox_dir_path=sandbox_dir, venv_path=str(venv_path), env=os.environ.copy(), force_recreate=True) + logger.info(f"Successfully recreated virtual environment at: {venv_path}") # Install pip requirements - try: - install_pip_requirements_for_sandbox(local_configs=local_configs, env=os.environ.copy()) - logger.info(f"Successfully installed pip requirements for venv at: {venv_path}") - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to install pip requirements: {e}") + install_pip_requirements_for_sandbox(local_configs=local_configs, env=os.environ.copy()) + logger.info(f"Successfully installed pip requirements for venv at: {venv_path}") return sbx_config diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 9822c889..87dd8532 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -7,7 +7,14 @@ from httpx import ConnectError, HTTPStatusError from pydantic import BaseModel, Field from starlette.responses import StreamingResponse -from letta.errors import LettaToolCreateError, LettaToolNameConflictError +from letta.errors import ( + LettaInvalidArgumentError, + LettaInvalidMCPSchemaError, + LettaMCPConnectionError, + LettaMCPTimeoutError, + LettaToolCreateError, + LettaToolNameConflictError, +) from letta.functions.functions import derive_openai_json_schema from letta.functions.mcp_client.exceptions import MCPTimeoutError from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig @@ -70,78 +77,74 @@ async def count_tools( """ Get a count of all tools available to agents belonging to the org of the user. """ - try: - # Helper function to parse tool types - supports both repeated params and comma-separated values - def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]: - if tool_types_input is None: - return None - # Flatten any comma-separated values and validate against ToolType enum - flattened_types = [] - for item in tool_types_input: - # Split by comma in case user provided comma-separated values - types_in_item = [t.strip() for t in item.split(",") if t.strip()] - flattened_types.extend(types_in_item) + # Helper function to parse tool types - supports both repeated params and comma-separated values + def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_types_input is None: + return None - # Validate each type against the ToolType enum - valid_types = [] - valid_values = [tt.value for tt in ToolType] + # Flatten any comma-separated values and validate against ToolType enum + flattened_types = [] + for item in tool_types_input: + # Split by comma in case user provided comma-separated values + types_in_item = [t.strip() for t in item.split(",") if t.strip()] + flattened_types.extend(types_in_item) - for tool_type in flattened_types: - if tool_type not in valid_values: - raise HTTPException( - status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}" - ) - valid_types.append(tool_type) + # Validate each type against the ToolType enum + valid_types = [] + valid_values = [tt.value for tt in ToolType] - return valid_types if valid_types else None + for tool_type in flattened_types: + if tool_type not in valid_values: + raise HTTPException(status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}") + valid_types.append(tool_type) - # Parse and validate tool types (same logic as list_tools) - tool_types_str = parse_tool_types(tool_types) - exclude_tool_types_str = parse_tool_types(exclude_tool_types) + return valid_types if valid_types else None - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + # Parse and validate tool types (same logic as list_tools) + tool_types_str = parse_tool_types(tool_types) + exclude_tool_types_str = parse_tool_types(exclude_tool_types) - # Combine single name with names list for unified processing (same logic as list_tools) - combined_names = [] - if name is not None: - combined_names.append(name) - if names is not None: - combined_names.extend(names) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # Use None if no names specified, otherwise use the combined list - final_names = combined_names if combined_names else None + # Combine single name with names list for unified processing (same logic as list_tools) + combined_names = [] + if name is not None: + combined_names.append(name) + if names is not None: + combined_names.extend(names) - # Helper function to parse tool IDs - supports both repeated params and comma-separated values - def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]: - if tool_ids_input is None: - return None + # Use None if no names specified, otherwise use the combined list + final_names = combined_names if combined_names else None - # Flatten any comma-separated values - flattened_ids = [] - for item in tool_ids_input: - # Split by comma in case user provided comma-separated values - ids_in_item = [id.strip() for id in item.split(",") if id.strip()] - flattened_ids.extend(ids_in_item) + # Helper function to parse tool IDs - supports both repeated params and comma-separated values + def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_ids_input is None: + return None - return flattened_ids if flattened_ids else None + # Flatten any comma-separated values + flattened_ids = [] + for item in tool_ids_input: + # Split by comma in case user provided comma-separated values + ids_in_item = [id.strip() for id in item.split(",") if id.strip()] + flattened_ids.extend(ids_in_item) - # Parse tool IDs (same logic as list_tools) - final_tool_ids = parse_tool_ids(tool_ids) + return flattened_ids if flattened_ids else None - # Get the count of tools using unified query - return await server.tool_manager.count_tools_async( - actor=actor, - tool_types=tool_types_str, - exclude_tool_types=exclude_tool_types_str, - names=final_names, - tool_ids=final_tool_ids, - search=search, - return_only_letta_tools=return_only_letta_tools, - exclude_letta_tools=exclude_letta_tools, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + # Parse tool IDs (same logic as list_tools) + final_tool_ids = parse_tool_ids(tool_ids) + + # Get the count of tools using unified query + return await server.tool_manager.count_tools_async( + actor=actor, + tool_types=tool_types_str, + exclude_tool_types=exclude_tool_types_str, + names=final_names, + tool_ids=final_tool_ids, + search=search, + return_only_letta_tools=return_only_letta_tools, + exclude_letta_tools=exclude_letta_tools, + ) @router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool") @@ -191,81 +194,77 @@ async def list_tools( """ Get a list of all tools available to agents. """ - try: - # Helper function to parse tool types - supports both repeated params and comma-separated values - def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]: - if tool_types_input is None: - return None - # Flatten any comma-separated values and validate against ToolType enum - flattened_types = [] - for item in tool_types_input: - # Split by comma in case user provided comma-separated values - types_in_item = [t.strip() for t in item.split(",") if t.strip()] - flattened_types.extend(types_in_item) + # Helper function to parse tool types - supports both repeated params and comma-separated values + def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_types_input is None: + return None - # Validate each type against the ToolType enum - valid_types = [] - valid_values = [tt.value for tt in ToolType] + # Flatten any comma-separated values and validate against ToolType enum + flattened_types = [] + for item in tool_types_input: + # Split by comma in case user provided comma-separated values + types_in_item = [t.strip() for t in item.split(",") if t.strip()] + flattened_types.extend(types_in_item) - for tool_type in flattened_types: - if tool_type not in valid_values: - raise HTTPException( - status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}" - ) - valid_types.append(tool_type) + # Validate each type against the ToolType enum + valid_types = [] + valid_values = [tt.value for tt in ToolType] - return valid_types if valid_types else None + for tool_type in flattened_types: + if tool_type not in valid_values: + raise HTTPException(status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}") + valid_types.append(tool_type) - # Parse and validate tool types - tool_types_str = parse_tool_types(tool_types) - exclude_tool_types_str = parse_tool_types(exclude_tool_types) + return valid_types if valid_types else None - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + # Parse and validate tool types + tool_types_str = parse_tool_types(tool_types) + exclude_tool_types_str = parse_tool_types(exclude_tool_types) - # Combine single name with names list for unified processing - combined_names = [] - if name is not None: - combined_names.append(name) - if names is not None: - combined_names.extend(names) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # Use None if no names specified, otherwise use the combined list - final_names = combined_names if combined_names else None + # Combine single name with names list for unified processing + combined_names = [] + if name is not None: + combined_names.append(name) + if names is not None: + combined_names.extend(names) - # Helper function to parse tool IDs - supports both repeated params and comma-separated values - def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]: - if tool_ids_input is None: - return None + # Use None if no names specified, otherwise use the combined list + final_names = combined_names if combined_names else None - # Flatten any comma-separated values - flattened_ids = [] - for item in tool_ids_input: - # Split by comma in case user provided comma-separated values - ids_in_item = [id.strip() for id in item.split(",") if id.strip()] - flattened_ids.extend(ids_in_item) + # Helper function to parse tool IDs - supports both repeated params and comma-separated values + def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]: + if tool_ids_input is None: + return None - return flattened_ids if flattened_ids else None + # Flatten any comma-separated values + flattened_ids = [] + for item in tool_ids_input: + # Split by comma in case user provided comma-separated values + ids_in_item = [id.strip() for id in item.split(",") if id.strip()] + flattened_ids.extend(ids_in_item) - # Parse tool IDs - final_tool_ids = parse_tool_ids(tool_ids) + return flattened_ids if flattened_ids else None - # Get the list of tools using unified query - return await server.tool_manager.list_tools_async( - actor=actor, - before=before, - after=after, - limit=limit, - ascending=(order == "asc"), - tool_types=tool_types_str, - exclude_tool_types=exclude_tool_types_str, - names=final_names, - tool_ids=final_tool_ids, - search=search, - return_only_letta_tools=return_only_letta_tools, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + # Parse tool IDs + final_tool_ids = parse_tool_ids(tool_ids) + + # Get the list of tools using unified query + return await server.tool_manager.list_tools_async( + actor=actor, + before=before, + after=after, + limit=limit, + ascending=(order == "asc"), + tool_types=tool_types_str, + exclude_tool_types=exclude_tool_types_str, + names=final_names, + tool_ids=final_tool_ids, + search=search, + return_only_letta_tools=return_only_letta_tools, + ) @router.post("/", response_model=Tool, operation_id="create_tool") @@ -379,38 +378,18 @@ async def list_mcp_tools_by_server( """ Get a list of all tools for a specific MCP server """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) mcp_tools = await server.mcp_manager.list_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor) return mcp_tools - except Exception as e: - if isinstance(e, ConnectError) or isinstance(e, ConnectionError): - raise HTTPException( - status_code=404, - detail={ - "code": "MCPListToolsError", - "message": str(e), - "mcp_server_name": mcp_server_name, - }, - ) - if isinstance(e, HTTPStatusError): - raise HTTPException( - status_code=401, - detail={ - "code": "MCPListToolsError", - "message": str(e), - "mcp_server_name": mcp_server_name, - }, - ) + except (ConnectError, ConnectionError) as e: + raise LettaMCPConnectionError(str(e), server_name=mcp_server_name) + except HTTPStatusError as e: + # HTTPStatusError from the MCP server likely means auth issue + if e.response.status_code == 401: + raise LettaMCPConnectionError(f"Authentication failed: {e}", server_name=mcp_server_name) else: - raise HTTPException( - status_code=500, - detail={ - "code": "MCPListToolsError", - "message": str(e), - "mcp_server_name": mcp_server_name, - }, - ) + raise LettaMCPConnectionError(f"HTTP error from MCP server: {e}", server_name=mcp_server_name) @router.post("/mcp/servers/{mcp_server_name}/resync", operation_id="resync_mcp_server_tools") @@ -430,29 +409,8 @@ async def resync_mcp_server_tools( Returns a summary of changes made. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - - try: - result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id) - return result - except ValueError as e: - raise HTTPException( - status_code=404, - detail={ - "code": "MCPServerNotFoundError", - "message": str(e), - "mcp_server_name": mcp_server_name, - }, - ) - except Exception as e: - logger.error(f"Unexpected error refreshing MCP server tools: {e}") - raise HTTPException( - status_code=404, - detail={ - "code": "MCPRefreshError", - "message": f"Failed to refresh MCP server tools: {str(e)}", - "mcp_server_name": mcp_server_name, - }, - ) + result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id) + return result @router.post("/mcp/servers/{mcp_server_name}/{mcp_tool_name}", response_model=Tool, operation_id="add_mcp_tool") @@ -470,25 +428,8 @@ async def add_mcp_tool( if tool_settings.mcp_read_from_config: try: available_tools = await server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name) - except ValueError as e: - # ValueError means that the MCP server name doesn't exist - raise HTTPException( - status_code=400, # Bad Request - detail={ - "code": "MCPServerNotFoundError", - "message": str(e), - "mcp_server_name": mcp_server_name, - }, - ) except MCPTimeoutError as e: - raise HTTPException( - status_code=408, # Timeout - detail={ - "code": "MCPTimeoutError", - "message": str(e), - "mcp_server_name": mcp_server_name, - }, - ) + raise LettaMCPTimeoutError(str(e), server_name=mcp_server_name) # See if the tool is in the available list mcp_tool = None @@ -497,27 +438,18 @@ async def add_mcp_tool( mcp_tool = tool break if not mcp_tool: - raise HTTPException( - status_code=400, # Bad Request - detail={ - "code": "MCPToolNotFoundError", - "message": f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}", - "mcp_tool_name": mcp_tool_name, - }, + raise LettaInvalidArgumentError( + f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}", + argument_name="mcp_tool_name", ) # Check tool health - reject only INVALID tools if mcp_tool.health: if mcp_tool.health.status == "INVALID": - raise HTTPException( - status_code=400, - detail={ - "code": "MCPToolSchemaInvalid", - "message": f"Tool {mcp_tool_name} has an invalid schema and cannot be attached", - "mcp_tool_name": mcp_tool_name, - "health_status": mcp_tool.health.status, - "reasons": mcp_tool.health.reasons, - }, + raise LettaInvalidMCPSchemaError( + server_name=mcp_server_name, + mcp_tool_name=mcp_tool_name, + reasons=mcp_tool.health.reasons, ) tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) @@ -544,40 +476,27 @@ async def add_mcp_server_to_config( """ Add a new MCP server to the Letta MCP server config """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - if tool_settings.mcp_read_from_config: - # write to config file - return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True) - else: - # log to DB - # Check if stdio servers are disabled - if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio: - raise HTTPException( - status_code=400, - detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server", - ) + if tool_settings.mcp_read_from_config: + # write to config file + return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True) + else: + # log to DB + # Check if stdio servers are disabled + if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio: + raise HTTPException( + status_code=400, + detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server", + ) - # Create MCP server and optimistically sync tools - # The mcp_manager will handle encryption of sensitive fields - await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor) + # Create MCP server and optimistically sync tools + # The mcp_manager will handle encryption of sensitive fields + await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor) - # TODO: don't do this in the future (just return MCPServer) - all_servers = await server.mcp_manager.list_mcp_servers(actor=actor) - return [server.to_config() for server in all_servers] - except UniqueConstraintViolationError: - # If server name already exists, throw 409 conflict error - raise HTTPException( - status_code=409, - detail={ - "code": "MCPServerNameAlreadyExistsError", - "message": f"MCP server with name '{request.server_name}' already exists", - "server_name": request.server_name, - }, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") + # TODO: don't do this in the future (just return MCPServer) + all_servers = await server.mcp_manager.list_mcp_servers(actor=actor) + return [server.to_config() for server in all_servers] @router.patch( @@ -594,21 +513,15 @@ async def update_mcp_server( """ Update an existing MCP server configuration """ - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - if tool_settings.mcp_read_from_config: - raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.") - else: - updated_server = await server.mcp_manager.update_mcp_server_by_name( - mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor - ) - return updated_server.to_config() - except HTTPException: - # Re-raise HTTP exceptions (like 404) - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") + if tool_settings.mcp_read_from_config: + raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.") + else: + updated_server = await server.mcp_manager.update_mcp_server_by_name( + mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor + ) + return updated_server.to_config() @router.delete( @@ -660,32 +573,9 @@ async def test_mcp_server( return {"status": "success", "tools": tools} except ConnectionError as e: - raise HTTPException( - status_code=400, - detail={ - "code": "MCPServerConnectionError", - "message": str(e), - "server_name": request.server_name, - }, - ) + raise LettaMCPConnectionError(str(e), server_name=request.server_name) except MCPTimeoutError as e: - raise HTTPException( - status_code=408, - detail={ - "code": "MCPTimeoutError", - "message": f"MCP server connection timed out: {str(e)}", - "server_name": request.server_name, - }, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail={ - "code": "MCPServerTestError", - "message": f"Failed to test MCP server: {str(e)}", - "server_name": request.server_name, - }, - ) + raise LettaMCPTimeoutError(f"MCP server connection timed out: {str(e)}", server_name=request.server_name) finally: if client: try: @@ -802,18 +692,14 @@ async def generate_json_schema( Generate a JSON schema from the given source code defining a function or class. Supports both Python and TypeScript source code. """ - try: - if request.source_type == "typescript": - from letta.functions.typescript_parser import derive_typescript_json_schema + if request.source_type == "typescript": + from letta.functions.typescript_parser import derive_typescript_json_schema - schema = derive_typescript_json_schema(source_code=request.code) - else: - # Default to Python for backwards compatibility - schema = derive_openai_json_schema(source_code=request.code) - return schema - - except Exception as e: - raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}") + schema = derive_typescript_json_schema(source_code=request.code) + else: + # Default to Python for backwards compatibility + schema = derive_openai_json_schema(source_code=request.code) + return schema # TODO: @jnjpng move this and other models above to appropriate file for schemas @@ -840,14 +726,9 @@ async def execute_mcp_tool( # Get the MCP server by name mcp_server = await server.mcp_manager.get_mcp_server(mcp_server_name, actor) if not mcp_server: - raise HTTPException( - status_code=404, - detail={ - "code": "MCPServerNotFound", - "message": f"MCP server '{mcp_server_name}' not found", - "server_name": mcp_server_name, - }, - ) + from letta.orm.errors import NoResultFound + + raise NoResultFound(f"MCP server '{mcp_server_name}' not found") # Create client and connect server_config = mcp_server.to_config() @@ -862,19 +743,6 @@ async def execute_mcp_tool( "result": result, "success": success, } - except HTTPException: - raise - except Exception as e: - logger.warning(f"Error executing MCP tool: {str(e)}") - raise HTTPException( - status_code=500, - detail={ - "code": "MCPToolExecutionError", - "message": f"Failed to execute MCP tool: {str(e)}", - "server_name": mcp_server_name, - "tool_name": tool_name, - }, - ) finally: if client: try: @@ -944,70 +812,64 @@ async def generate_tool_from_prompt( """ Generate a tool from the given user prompt. """ - response_data = None + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620") + formatted_prompt = ( + f"Generate a python function named {request.tool_name} using the instructions below " + + (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n") + + (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n") + + f"Instructions: {request.prompt}" + ) + llm_client = LLMClient.create( + provider_type=llm_config.model_endpoint_type, + actor=actor, + ) + assert llm_client is not None - try: - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620") - formatted_prompt = ( - f"Generate a python function named {request.tool_name} using the instructions below " - + (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n") - + (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n") - + f"Instructions: {request.prompt}" - ) - llm_client = LLMClient.create( - provider_type=llm_config.model_endpoint_type, - actor=actor, - ) - assert llm_client is not None + assistant_message_ack = "Understood, I will respond with generated python source code and sample arguments that can be used to test the functionality once I receive the user prompt. I'm ready." - assistant_message_ack = "Understood, I will respond with generated python source code and sample arguments that can be used to test the functionality once I receive the user prompt. I'm ready." + input_messages = [ + Message(role=MessageRole.system, content=[TextContent(text=get_system_text("memgpt_generate_tool"))]), + Message(role=MessageRole.assistant, content=[TextContent(text=assistant_message_ack)]), + Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]), + ] - input_messages = [ - Message(role=MessageRole.system, content=[TextContent(text=get_system_text("memgpt_generate_tool"))]), - Message(role=MessageRole.assistant, content=[TextContent(text=assistant_message_ack)]), - Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]), - ] - - tool = { - "name": "generate_tool", - "description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.", - "parameters": { - "type": "object", - "properties": { - "raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."}, - "sample_args_json": { - "type": "string", - "description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.", - }, - "pip_requirements_json": { - "type": "string", - "description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.", - }, + tool = { + "name": "generate_tool", + "description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.", + "parameters": { + "type": "object", + "properties": { + "raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."}, + "sample_args_json": { + "type": "string", + "description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.", + }, + "pip_requirements_json": { + "type": "string", + "description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.", }, - "required": ["raw_source_code", "sample_args_json", "pip_requirements_json"], }, - } - request_data = llm_client.build_request_data( - AgentType.letta_v1_agent, - input_messages, - llm_config, - tools=[tool], - ) - response_data = await llm_client.request_async(request_data, llm_config) - response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config) - output = json.loads(response.choices[0].message.tool_calls[0].function.arguments) - pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()] - return GenerateToolOutput( - tool=Tool( - name=request.tool_name, - source_type="python", - source_code=output["raw_source_code"], - pip_requirements=pip_requirements, - ), - sample_args=json.loads(output["sample_args_json"]), - response=response.choices[0].message.content, - ) - except Exception as e: - logger.error(f"Failed to generate tool: {str(e)}. Raw response: {response_data}") - raise HTTPException(status_code=500, detail=f"Failed to generate tool: {str(e)}") + "required": ["raw_source_code", "sample_args_json", "pip_requirements_json"], + }, + } + request_data = llm_client.build_request_data( + AgentType.letta_v1_agent, + input_messages, + llm_config, + tools=[tool], + ) + response_data = await llm_client.request_async(request_data, llm_config) + response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config) + output = json.loads(response.choices[0].message.tool_calls[0].function.arguments) + pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()] + return GenerateToolOutput( + tool=Tool( + name=request.tool_name, + source_type="python", + source_code=output["raw_source_code"], + pip_requirements=pip_requirements, + ), + sample_args=json.loads(output["sample_args_json"]), + response=response.choices[0].message.content, + ) diff --git a/tests/mcp_tests/test_mcp_schema_validation.py b/tests/mcp_tests/test_mcp_schema_validation.py index 7a92630f..cdadd155 100644 --- a/tests/mcp_tests/test_mcp_schema_validation.py +++ b/tests/mcp_tests/test_mcp_schema_validation.py @@ -163,12 +163,14 @@ async def test_add_mcp_tool_rejects_invalid_schemas(): # Should raise HTTPException for invalid schema headers = HeaderParams(actor_id="test_user") - with pytest.raises(HTTPException) as exc_info: + from letta.errors import LettaInvalidMCPSchemaError + + with pytest.raises(LettaInvalidMCPSchemaError) as exc_info: await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, headers=headers) - assert exc_info.value.status_code == 400 - assert "invalid schema" in exc_info.value.detail["message"].lower() - assert exc_info.value.detail["health_status"] == SchemaHealth.INVALID.value + assert "invalid schema" in exc_info.value.message.lower() + assert exc_info.value.details["mcp_tool_name"] == "test_tool" + assert exc_info.value.details["reasons"] == ["Missing 'type' at root level"] def test_mcp_schema_healing_for_optional_fields():