feat: exception handling middleware for sandbox_configs + identities + tools (#5143)
This commit is contained in:
committed by
Caren Thomas
parent
307c85ca9a
commit
324933edd3
@@ -109,6 +109,18 @@ class LettaMCPError(LettaError):
|
||||
"""Base error for MCP-related issues."""
|
||||
|
||||
|
||||
class LettaInvalidMCPSchemaError(LettaMCPError):
|
||||
"""Error raised when an invalid MCP schema is provided."""
|
||||
|
||||
def __init__(self, server_name: str, mcp_tool_name: str, reasons: List[str]):
|
||||
details = {"server_name": server_name, "mcp_tool_name": mcp_tool_name, "reasons": reasons}
|
||||
super().__init__(
|
||||
message=f"MCP tool {mcp_tool_name} has an invalid schema and cannot be attached - reasons: {reasons}",
|
||||
code=ErrorCode.INVALID_ARGUMENT,
|
||||
details=details,
|
||||
)
|
||||
|
||||
|
||||
class LettaMCPConnectionError(LettaMCPError):
|
||||
"""Error raised when unable to connect to MCP server."""
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from letta.errors import (
|
||||
BedrockPermissionError,
|
||||
LettaAgentNotFoundError,
|
||||
LettaInvalidArgumentError,
|
||||
LettaInvalidMCPSchemaError,
|
||||
LettaMCPConnectionError,
|
||||
LettaMCPTimeoutError,
|
||||
LettaToolCreateError,
|
||||
@@ -264,6 +265,7 @@ def create_application() -> "FastAPI":
|
||||
|
||||
# 408 Timeout errors
|
||||
app.add_exception_handler(LettaMCPTimeoutError, _error_handler_408)
|
||||
app.add_exception_handler(LettaInvalidMCPSchemaError, _error_handler_400)
|
||||
|
||||
# 409 Conflict errors
|
||||
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query
|
||||
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -39,26 +39,19 @@ async def list_identities(
|
||||
"""
|
||||
Get a list of all identities in the database
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
identities = await server.identity_manager.list_identities_async(
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
identifier_key=identifier_key,
|
||||
identity_type=identity_type,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
actor=actor,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
identities = await server.identity_manager.list_identities_async(
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
identifier_key=identifier_key,
|
||||
identity_type=identity_type,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
actor=actor,
|
||||
)
|
||||
return identities
|
||||
|
||||
|
||||
@@ -70,15 +63,11 @@ async def count_identities(
|
||||
"""
|
||||
Get count of all identities for a user
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.size_async(actor=actor)
|
||||
except NoResultFound:
|
||||
return 0
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
|
||||
@@ -87,11 +76,8 @@ async def retrieve_identity(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
|
||||
|
||||
|
||||
@router.post("/", tags=["identities"], response_model=Identity, operation_id="create_identity")
|
||||
@@ -103,21 +89,8 @@ async def create_identity(
|
||||
None, alias="X-Project", description="The project slug to associate with the identity (cloud only)."
|
||||
), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.create_identity_async(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except UniqueConstraintViolationError:
|
||||
if identity.project_id:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"An identity with identifier key {identity.identifier_key} already exists for project {identity.project_id}",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=409, detail=f"An identity with identifier key {identity.identifier_key} already exists")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.create_identity_async(identity=identity, actor=actor)
|
||||
|
||||
|
||||
@router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity")
|
||||
@@ -129,15 +102,8 @@ async def upsert_identity(
|
||||
None, alias="X-Project", description="The project slug to associate with the identity (cloud only)."
|
||||
), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity")
|
||||
@@ -147,15 +113,8 @@ async def modify_identity(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor)
|
||||
|
||||
|
||||
@router.put("/{identity_id}/properties", tags=["identities"], operation_id="upsert_identity_properties")
|
||||
@@ -165,15 +124,8 @@ async def upsert_identity_properties(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
):
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor)
|
||||
|
||||
|
||||
@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity")
|
||||
@@ -185,15 +137,8 @@ async def delete_identity(
|
||||
"""
|
||||
Delete an identity by its identifier key
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{identity_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_identity")
|
||||
@@ -218,20 +163,15 @@ async def list_agents_for_identity(
|
||||
"""
|
||||
Get all agents associated with the specified identity.
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.list_agents_for_identity_async(
|
||||
identity_id=identity_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.list_agents_for_identity_async(
|
||||
identity_id=identity_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{identity_id}/blocks", response_model=List[Block], operation_id="list_blocks_for_identity")
|
||||
@@ -256,17 +196,12 @@ async def list_blocks_for_identity(
|
||||
"""
|
||||
Get all blocks associated with the specified identity.
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.list_blocks_for_identity_async(
|
||||
identity_id=identity_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.identity_manager.list_blocks_for_identity_async(
|
||||
identity_id=identity_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,9 @@ import os
|
||||
import shutil
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.environment_variables import (
|
||||
@@ -68,9 +69,8 @@ async def create_custom_local_sandbox_config(
|
||||
"""
|
||||
# Ensure the incoming config is of type LOCAL
|
||||
if local_sandbox_config.type != SandboxType.LOCAL:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provided config must be of type '{SandboxType.LOCAL.value}'.",
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Provided config must be of type '{SandboxType.LOCAL.value}'.", argument_name="local_sandbox_config.type"
|
||||
)
|
||||
|
||||
# Retrieve the user (actor)
|
||||
@@ -138,25 +138,16 @@ async def force_recreate_local_sandbox_venv(
|
||||
|
||||
# Check if venv exists, and delete if necessary
|
||||
if os.path.isdir(venv_path):
|
||||
try:
|
||||
shutil.rmtree(venv_path)
|
||||
logger.info(f"Deleted existing virtual environment at: {venv_path}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete existing venv: {e}")
|
||||
shutil.rmtree(venv_path)
|
||||
logger.info(f"Deleted existing virtual environment at: {venv_path}")
|
||||
|
||||
# Recreate the virtual environment
|
||||
try:
|
||||
create_venv_for_local_sandbox(sandbox_dir_path=sandbox_dir, venv_path=str(venv_path), env=os.environ.copy(), force_recreate=True)
|
||||
logger.info(f"Successfully recreated virtual environment at: {venv_path}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to recreate venv: {e}")
|
||||
create_venv_for_local_sandbox(sandbox_dir_path=sandbox_dir, venv_path=str(venv_path), env=os.environ.copy(), force_recreate=True)
|
||||
logger.info(f"Successfully recreated virtual environment at: {venv_path}")
|
||||
|
||||
# Install pip requirements
|
||||
try:
|
||||
install_pip_requirements_for_sandbox(local_configs=local_configs, env=os.environ.copy())
|
||||
logger.info(f"Successfully installed pip requirements for venv at: {venv_path}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to install pip requirements: {e}")
|
||||
install_pip_requirements_for_sandbox(local_configs=local_configs, env=os.environ.copy())
|
||||
logger.info(f"Successfully installed pip requirements for venv at: {venv_path}")
|
||||
|
||||
return sbx_config
|
||||
|
||||
|
||||
@@ -7,7 +7,14 @@ from httpx import ConnectError, HTTPStatusError
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.errors import LettaToolCreateError, LettaToolNameConflictError
|
||||
from letta.errors import (
|
||||
LettaInvalidArgumentError,
|
||||
LettaInvalidMCPSchemaError,
|
||||
LettaMCPConnectionError,
|
||||
LettaMCPTimeoutError,
|
||||
LettaToolCreateError,
|
||||
LettaToolNameConflictError,
|
||||
)
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
||||
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
@@ -70,78 +77,74 @@ async def count_tools(
|
||||
"""
|
||||
Get a count of all tools available to agents belonging to the org of the user.
|
||||
"""
|
||||
try:
|
||||
# Helper function to parse tool types - supports both repeated params and comma-separated values
|
||||
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_types_input is None:
|
||||
return None
|
||||
|
||||
# Flatten any comma-separated values and validate against ToolType enum
|
||||
flattened_types = []
|
||||
for item in tool_types_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
|
||||
flattened_types.extend(types_in_item)
|
||||
# Helper function to parse tool types - supports both repeated params and comma-separated values
|
||||
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_types_input is None:
|
||||
return None
|
||||
|
||||
# Validate each type against the ToolType enum
|
||||
valid_types = []
|
||||
valid_values = [tt.value for tt in ToolType]
|
||||
# Flatten any comma-separated values and validate against ToolType enum
|
||||
flattened_types = []
|
||||
for item in tool_types_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
|
||||
flattened_types.extend(types_in_item)
|
||||
|
||||
for tool_type in flattened_types:
|
||||
if tool_type not in valid_values:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
|
||||
)
|
||||
valid_types.append(tool_type)
|
||||
# Validate each type against the ToolType enum
|
||||
valid_types = []
|
||||
valid_values = [tt.value for tt in ToolType]
|
||||
|
||||
return valid_types if valid_types else None
|
||||
for tool_type in flattened_types:
|
||||
if tool_type not in valid_values:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}")
|
||||
valid_types.append(tool_type)
|
||||
|
||||
# Parse and validate tool types (same logic as list_tools)
|
||||
tool_types_str = parse_tool_types(tool_types)
|
||||
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
|
||||
return valid_types if valid_types else None
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# Parse and validate tool types (same logic as list_tools)
|
||||
tool_types_str = parse_tool_types(tool_types)
|
||||
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
|
||||
|
||||
# Combine single name with names list for unified processing (same logic as list_tools)
|
||||
combined_names = []
|
||||
if name is not None:
|
||||
combined_names.append(name)
|
||||
if names is not None:
|
||||
combined_names.extend(names)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Use None if no names specified, otherwise use the combined list
|
||||
final_names = combined_names if combined_names else None
|
||||
# Combine single name with names list for unified processing (same logic as list_tools)
|
||||
combined_names = []
|
||||
if name is not None:
|
||||
combined_names.append(name)
|
||||
if names is not None:
|
||||
combined_names.extend(names)
|
||||
|
||||
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
|
||||
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_ids_input is None:
|
||||
return None
|
||||
# Use None if no names specified, otherwise use the combined list
|
||||
final_names = combined_names if combined_names else None
|
||||
|
||||
# Flatten any comma-separated values
|
||||
flattened_ids = []
|
||||
for item in tool_ids_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
|
||||
flattened_ids.extend(ids_in_item)
|
||||
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
|
||||
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_ids_input is None:
|
||||
return None
|
||||
|
||||
return flattened_ids if flattened_ids else None
|
||||
# Flatten any comma-separated values
|
||||
flattened_ids = []
|
||||
for item in tool_ids_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
|
||||
flattened_ids.extend(ids_in_item)
|
||||
|
||||
# Parse tool IDs (same logic as list_tools)
|
||||
final_tool_ids = parse_tool_ids(tool_ids)
|
||||
return flattened_ids if flattened_ids else None
|
||||
|
||||
# Get the count of tools using unified query
|
||||
return await server.tool_manager.count_tools_async(
|
||||
actor=actor,
|
||||
tool_types=tool_types_str,
|
||||
exclude_tool_types=exclude_tool_types_str,
|
||||
names=final_names,
|
||||
tool_ids=final_tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
exclude_letta_tools=exclude_letta_tools,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
# Parse tool IDs (same logic as list_tools)
|
||||
final_tool_ids = parse_tool_ids(tool_ids)
|
||||
|
||||
# Get the count of tools using unified query
|
||||
return await server.tool_manager.count_tools_async(
|
||||
actor=actor,
|
||||
tool_types=tool_types_str,
|
||||
exclude_tool_types=exclude_tool_types_str,
|
||||
names=final_names,
|
||||
tool_ids=final_tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
exclude_letta_tools=exclude_letta_tools,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool")
|
||||
@@ -191,81 +194,77 @@ async def list_tools(
|
||||
"""
|
||||
Get a list of all tools available to agents.
|
||||
"""
|
||||
try:
|
||||
# Helper function to parse tool types - supports both repeated params and comma-separated values
|
||||
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_types_input is None:
|
||||
return None
|
||||
|
||||
# Flatten any comma-separated values and validate against ToolType enum
|
||||
flattened_types = []
|
||||
for item in tool_types_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
|
||||
flattened_types.extend(types_in_item)
|
||||
# Helper function to parse tool types - supports both repeated params and comma-separated values
|
||||
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_types_input is None:
|
||||
return None
|
||||
|
||||
# Validate each type against the ToolType enum
|
||||
valid_types = []
|
||||
valid_values = [tt.value for tt in ToolType]
|
||||
# Flatten any comma-separated values and validate against ToolType enum
|
||||
flattened_types = []
|
||||
for item in tool_types_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
|
||||
flattened_types.extend(types_in_item)
|
||||
|
||||
for tool_type in flattened_types:
|
||||
if tool_type not in valid_values:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
|
||||
)
|
||||
valid_types.append(tool_type)
|
||||
# Validate each type against the ToolType enum
|
||||
valid_types = []
|
||||
valid_values = [tt.value for tt in ToolType]
|
||||
|
||||
return valid_types if valid_types else None
|
||||
for tool_type in flattened_types:
|
||||
if tool_type not in valid_values:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}")
|
||||
valid_types.append(tool_type)
|
||||
|
||||
# Parse and validate tool types
|
||||
tool_types_str = parse_tool_types(tool_types)
|
||||
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
|
||||
return valid_types if valid_types else None
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# Parse and validate tool types
|
||||
tool_types_str = parse_tool_types(tool_types)
|
||||
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
|
||||
|
||||
# Combine single name with names list for unified processing
|
||||
combined_names = []
|
||||
if name is not None:
|
||||
combined_names.append(name)
|
||||
if names is not None:
|
||||
combined_names.extend(names)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
# Use None if no names specified, otherwise use the combined list
|
||||
final_names = combined_names if combined_names else None
|
||||
# Combine single name with names list for unified processing
|
||||
combined_names = []
|
||||
if name is not None:
|
||||
combined_names.append(name)
|
||||
if names is not None:
|
||||
combined_names.extend(names)
|
||||
|
||||
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
|
||||
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_ids_input is None:
|
||||
return None
|
||||
# Use None if no names specified, otherwise use the combined list
|
||||
final_names = combined_names if combined_names else None
|
||||
|
||||
# Flatten any comma-separated values
|
||||
flattened_ids = []
|
||||
for item in tool_ids_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
|
||||
flattened_ids.extend(ids_in_item)
|
||||
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
|
||||
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
|
||||
if tool_ids_input is None:
|
||||
return None
|
||||
|
||||
return flattened_ids if flattened_ids else None
|
||||
# Flatten any comma-separated values
|
||||
flattened_ids = []
|
||||
for item in tool_ids_input:
|
||||
# Split by comma in case user provided comma-separated values
|
||||
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
|
||||
flattened_ids.extend(ids_in_item)
|
||||
|
||||
# Parse tool IDs
|
||||
final_tool_ids = parse_tool_ids(tool_ids)
|
||||
return flattened_ids if flattened_ids else None
|
||||
|
||||
# Get the list of tools using unified query
|
||||
return await server.tool_manager.list_tools_async(
|
||||
actor=actor,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
tool_types=tool_types_str,
|
||||
exclude_tool_types=exclude_tool_types_str,
|
||||
names=final_names,
|
||||
tool_ids=final_tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
# Parse tool IDs
|
||||
final_tool_ids = parse_tool_ids(tool_ids)
|
||||
|
||||
# Get the list of tools using unified query
|
||||
return await server.tool_manager.list_tools_async(
|
||||
actor=actor,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=(order == "asc"),
|
||||
tool_types=tool_types_str,
|
||||
exclude_tool_types=exclude_tool_types_str,
|
||||
names=final_names,
|
||||
tool_ids=final_tool_ids,
|
||||
search=search,
|
||||
return_only_letta_tools=return_only_letta_tools,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=Tool, operation_id="create_tool")
|
||||
@@ -379,38 +378,18 @@ async def list_mcp_tools_by_server(
|
||||
"""
|
||||
Get a list of all tools for a specific MCP server
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_tools = await server.mcp_manager.list_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor)
|
||||
return mcp_tools
|
||||
except Exception as e:
|
||||
if isinstance(e, ConnectError) or isinstance(e, ConnectionError):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPListToolsError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
if isinstance(e, HTTPStatusError):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"code": "MCPListToolsError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except (ConnectError, ConnectionError) as e:
|
||||
raise LettaMCPConnectionError(str(e), server_name=mcp_server_name)
|
||||
except HTTPStatusError as e:
|
||||
# HTTPStatusError from the MCP server likely means auth issue
|
||||
if e.response.status_code == 401:
|
||||
raise LettaMCPConnectionError(f"Authentication failed: {e}", server_name=mcp_server_name)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": "MCPListToolsError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
raise LettaMCPConnectionError(f"HTTP error from MCP server: {e}", server_name=mcp_server_name)
|
||||
|
||||
|
||||
@router.post("/mcp/servers/{mcp_server_name}/resync", operation_id="resync_mcp_server_tools")
|
||||
@@ -430,29 +409,8 @@ async def resync_mcp_server_tools(
|
||||
Returns a summary of changes made.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
try:
|
||||
result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error refreshing MCP server tools: {e}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPRefreshError",
|
||||
"message": f"Failed to refresh MCP server tools: {str(e)}",
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/mcp/servers/{mcp_server_name}/{mcp_tool_name}", response_model=Tool, operation_id="add_mcp_tool")
|
||||
@@ -470,25 +428,8 @@ async def add_mcp_tool(
|
||||
if tool_settings.mcp_read_from_config:
|
||||
try:
|
||||
available_tools = await server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name)
|
||||
except ValueError as e:
|
||||
# ValueError means that the MCP server name doesn't exist
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408, # Timeout
|
||||
detail={
|
||||
"code": "MCPTimeoutError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
raise LettaMCPTimeoutError(str(e), server_name=mcp_server_name)
|
||||
|
||||
# See if the tool is in the available list
|
||||
mcp_tool = None
|
||||
@@ -497,27 +438,18 @@ async def add_mcp_tool(
|
||||
mcp_tool = tool
|
||||
break
|
||||
if not mcp_tool:
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPToolNotFoundError",
|
||||
"message": f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}",
|
||||
"mcp_tool_name": mcp_tool_name,
|
||||
},
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}",
|
||||
argument_name="mcp_tool_name",
|
||||
)
|
||||
|
||||
# Check tool health - reject only INVALID tools
|
||||
if mcp_tool.health:
|
||||
if mcp_tool.health.status == "INVALID":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"code": "MCPToolSchemaInvalid",
|
||||
"message": f"Tool {mcp_tool_name} has an invalid schema and cannot be attached",
|
||||
"mcp_tool_name": mcp_tool_name,
|
||||
"health_status": mcp_tool.health.status,
|
||||
"reasons": mcp_tool.health.reasons,
|
||||
},
|
||||
raise LettaInvalidMCPSchemaError(
|
||||
server_name=mcp_server_name,
|
||||
mcp_tool_name=mcp_tool_name,
|
||||
reasons=mcp_tool.health.reasons,
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
@@ -544,40 +476,27 @@ async def add_mcp_server_to_config(
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
# write to config file
|
||||
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
else:
|
||||
# log to DB
|
||||
# Check if stdio servers are disabled
|
||||
if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
|
||||
)
|
||||
if tool_settings.mcp_read_from_config:
|
||||
# write to config file
|
||||
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
else:
|
||||
# log to DB
|
||||
# Check if stdio servers are disabled
|
||||
if isinstance(request, StdioServerConfig) and tool_settings.mcp_disable_stdio:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server",
|
||||
)
|
||||
|
||||
# Create MCP server and optimistically sync tools
|
||||
# The mcp_manager will handle encryption of sensitive fields
|
||||
await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
|
||||
# Create MCP server and optimistically sync tools
|
||||
# The mcp_manager will handle encryption of sensitive fields
|
||||
await server.mcp_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
except UniqueConstraintViolationError:
|
||||
# If server name already exists, throw 409 conflict error
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"code": "MCPServerNameAlreadyExistsError",
|
||||
"message": f"MCP server with name '{request.server_name}' already exists",
|
||||
"server_name": request.server_name,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -594,21 +513,15 @@ async def update_mcp_server(
|
||||
"""
|
||||
Update an existing MCP server configuration
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.")
|
||||
else:
|
||||
updated_server = await server.mcp_manager.update_mcp_server_by_name(
|
||||
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
|
||||
)
|
||||
return updated_server.to_config()
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like 404)
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
if tool_settings.mcp_read_from_config:
|
||||
raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.")
|
||||
else:
|
||||
updated_server = await server.mcp_manager.update_mcp_server_by_name(
|
||||
mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor
|
||||
)
|
||||
return updated_server.to_config()
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -660,32 +573,9 @@ async def test_mcp_server(
|
||||
|
||||
return {"status": "success", "tools": tools}
|
||||
except ConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"code": "MCPServerConnectionError",
|
||||
"message": str(e),
|
||||
"server_name": request.server_name,
|
||||
},
|
||||
)
|
||||
raise LettaMCPConnectionError(str(e), server_name=request.server_name)
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408,
|
||||
detail={
|
||||
"code": "MCPTimeoutError",
|
||||
"message": f"MCP server connection timed out: {str(e)}",
|
||||
"server_name": request.server_name,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": "MCPServerTestError",
|
||||
"message": f"Failed to test MCP server: {str(e)}",
|
||||
"server_name": request.server_name,
|
||||
},
|
||||
)
|
||||
raise LettaMCPTimeoutError(f"MCP server connection timed out: {str(e)}", server_name=request.server_name)
|
||||
finally:
|
||||
if client:
|
||||
try:
|
||||
@@ -802,18 +692,14 @@ async def generate_json_schema(
|
||||
Generate a JSON schema from the given source code defining a function or class.
|
||||
Supports both Python and TypeScript source code.
|
||||
"""
|
||||
try:
|
||||
if request.source_type == "typescript":
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
if request.source_type == "typescript":
|
||||
from letta.functions.typescript_parser import derive_typescript_json_schema
|
||||
|
||||
schema = derive_typescript_json_schema(source_code=request.code)
|
||||
else:
|
||||
# Default to Python for backwards compatibility
|
||||
schema = derive_openai_json_schema(source_code=request.code)
|
||||
return schema
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
|
||||
schema = derive_typescript_json_schema(source_code=request.code)
|
||||
else:
|
||||
# Default to Python for backwards compatibility
|
||||
schema = derive_openai_json_schema(source_code=request.code)
|
||||
return schema
|
||||
|
||||
|
||||
# TODO: @jnjpng move this and other models above to appropriate file for schemas
|
||||
@@ -840,14 +726,9 @@ async def execute_mcp_tool(
|
||||
# Get the MCP server by name
|
||||
mcp_server = await server.mcp_manager.get_mcp_server(mcp_server_name, actor)
|
||||
if not mcp_server:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPServerNotFound",
|
||||
"message": f"MCP server '{mcp_server_name}' not found",
|
||||
"server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
from letta.orm.errors import NoResultFound
|
||||
|
||||
raise NoResultFound(f"MCP server '{mcp_server_name}' not found")
|
||||
|
||||
# Create client and connect
|
||||
server_config = mcp_server.to_config()
|
||||
@@ -862,19 +743,6 @@ async def execute_mcp_tool(
|
||||
"result": result,
|
||||
"success": success,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error executing MCP tool: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"code": "MCPToolExecutionError",
|
||||
"message": f"Failed to execute MCP tool: {str(e)}",
|
||||
"server_name": mcp_server_name,
|
||||
"tool_name": tool_name,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
if client:
|
||||
try:
|
||||
@@ -944,70 +812,64 @@ async def generate_tool_from_prompt(
|
||||
"""
|
||||
Generate a tool from the given user prompt.
|
||||
"""
|
||||
response_data = None
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620")
|
||||
formatted_prompt = (
|
||||
f"Generate a python function named {request.tool_name} using the instructions below "
|
||||
+ (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n")
|
||||
+ (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n")
|
||||
+ f"Instructions: {request.prompt}"
|
||||
)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=llm_config.model_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
assert llm_client is not None
|
||||
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620")
|
||||
formatted_prompt = (
|
||||
f"Generate a python function named {request.tool_name} using the instructions below "
|
||||
+ (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n")
|
||||
+ (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n")
|
||||
+ f"Instructions: {request.prompt}"
|
||||
)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=llm_config.model_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
assert llm_client is not None
|
||||
assistant_message_ack = "Understood, I will respond with generated python source code and sample arguments that can be used to test the functionality once I receive the user prompt. I'm ready."
|
||||
|
||||
assistant_message_ack = "Understood, I will respond with generated python source code and sample arguments that can be used to test the functionality once I receive the user prompt. I'm ready."
|
||||
input_messages = [
|
||||
Message(role=MessageRole.system, content=[TextContent(text=get_system_text("memgpt_generate_tool"))]),
|
||||
Message(role=MessageRole.assistant, content=[TextContent(text=assistant_message_ack)]),
|
||||
Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]),
|
||||
]
|
||||
|
||||
input_messages = [
|
||||
Message(role=MessageRole.system, content=[TextContent(text=get_system_text("memgpt_generate_tool"))]),
|
||||
Message(role=MessageRole.assistant, content=[TextContent(text=assistant_message_ack)]),
|
||||
Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]),
|
||||
]
|
||||
|
||||
tool = {
|
||||
"name": "generate_tool",
|
||||
"description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."},
|
||||
"sample_args_json": {
|
||||
"type": "string",
|
||||
"description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.",
|
||||
},
|
||||
"pip_requirements_json": {
|
||||
"type": "string",
|
||||
"description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.",
|
||||
},
|
||||
tool = {
|
||||
"name": "generate_tool",
|
||||
"description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."},
|
||||
"sample_args_json": {
|
||||
"type": "string",
|
||||
"description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.",
|
||||
},
|
||||
"pip_requirements_json": {
|
||||
"type": "string",
|
||||
"description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.",
|
||||
},
|
||||
"required": ["raw_source_code", "sample_args_json", "pip_requirements_json"],
|
||||
},
|
||||
}
|
||||
request_data = llm_client.build_request_data(
|
||||
AgentType.letta_v1_agent,
|
||||
input_messages,
|
||||
llm_config,
|
||||
tools=[tool],
|
||||
)
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
||||
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]
|
||||
return GenerateToolOutput(
|
||||
tool=Tool(
|
||||
name=request.tool_name,
|
||||
source_type="python",
|
||||
source_code=output["raw_source_code"],
|
||||
pip_requirements=pip_requirements,
|
||||
),
|
||||
sample_args=json.loads(output["sample_args_json"]),
|
||||
response=response.choices[0].message.content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate tool: {str(e)}. Raw response: {response_data}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to generate tool: {str(e)}")
|
||||
"required": ["raw_source_code", "sample_args_json", "pip_requirements_json"],
|
||||
},
|
||||
}
|
||||
request_data = llm_client.build_request_data(
|
||||
AgentType.letta_v1_agent,
|
||||
input_messages,
|
||||
llm_config,
|
||||
tools=[tool],
|
||||
)
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
||||
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]
|
||||
return GenerateToolOutput(
|
||||
tool=Tool(
|
||||
name=request.tool_name,
|
||||
source_type="python",
|
||||
source_code=output["raw_source_code"],
|
||||
pip_requirements=pip_requirements,
|
||||
),
|
||||
sample_args=json.loads(output["sample_args_json"]),
|
||||
response=response.choices[0].message.content,
|
||||
)
|
||||
|
||||
@@ -163,12 +163,14 @@ async def test_add_mcp_tool_rejects_invalid_schemas():
|
||||
|
||||
# Should raise HTTPException for invalid schema
|
||||
headers = HeaderParams(actor_id="test_user")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
from letta.errors import LettaInvalidMCPSchemaError
|
||||
|
||||
with pytest.raises(LettaInvalidMCPSchemaError) as exc_info:
|
||||
await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, headers=headers)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "invalid schema" in exc_info.value.detail["message"].lower()
|
||||
assert exc_info.value.detail["health_status"] == SchemaHealth.INVALID.value
|
||||
assert "invalid schema" in exc_info.value.message.lower()
|
||||
assert exc_info.value.details["mcp_tool_name"] == "test_tool"
|
||||
assert exc_info.value.details["reasons"] == ["Missing 'type' at root level"]
|
||||
|
||||
|
||||
def test_mcp_schema_healing_for_optional_fields():
|
||||
|
||||
Reference in New Issue
Block a user