fix: pass tool errors through the HTTP POST /tools requests with detailed error messages (#2110)

This commit is contained in:
Charles Packer
2024-11-26 17:06:44 -08:00
committed by GitHub
parent 056cbb0eec
commit cfb48a112f
4 changed files with 36 additions and 14 deletions

View File

@@ -10,6 +10,18 @@ class LettaError(Exception):
"""Base class for all Letta related errors."""
class LettaToolCreateError(LettaError):
"""Error raised when a tool cannot be created."""
default_error_message = "Error creating tool."
def __init__(self, message=None):
if message is None:
message = self.default_error_message
self.message = message
super().__init__(self.message)
class LLMError(LettaError):
pass

View File

@@ -3,9 +3,10 @@ import inspect
import os
from textwrap import dedent # remove indentation
from types import ModuleType
from typing import Optional, List
from typing import Dict, List, Optional
from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LettaToolCreateError
from letta.functions.schema_generator import generate_schema
@@ -13,10 +14,7 @@ def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> d
# auto-generate openai schema
try:
# Define a custom environment with necessary imports
env = {
"Optional": Optional, # Add any other required imports here
"List": List
}
env = {"Optional": Optional, "List": List, "Dict": Dict} # Add any other required imports here
env.update(globals())
exec(source_code, env)
@@ -29,7 +27,7 @@ def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> d
json_schema = generate_schema(func, name=name)
return json_schema
except Exception as e:
raise RuntimeError(f"Failed to execute source code: {e}")
raise LettaToolCreateError(f"Failed to derive JSON schema from source code: {e}")
def parse_source_code(func) -> str:

View File

@@ -131,11 +131,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
else:
# Add parameter details to the schema
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
schema["parameters"]["properties"][param.name] = {
# "type": "string" if param.annotation == str else str(param.annotation),
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
"description": param_doc.description,
}
if param_doc:
schema["parameters"]["properties"][param.name] = {
# "type": "string" if param.annotation == str else str(param.annotation),
"type": type_to_json_schema_type(param.annotation) if param.annotation != inspect.Parameter.empty else "string",
"description": param_doc.description,
}
if param.default == inspect.Parameter.empty:
schema["parameters"]["required"].append(param.name)

View File

@@ -2,6 +2,7 @@ from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.errors import LettaToolCreateError
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
@@ -14,12 +15,13 @@ router = APIRouter(prefix="/tools", tags=["tools"])
def delete_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a tool by name
"""
# actor = server.get_user_or_default(user_id=user_id)
server.tool_manager.delete_tool(tool_id=tool_id)
actor = server.get_user_or_default(user_id=user_id)
server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor)
@router.get("/{tool_id}", response_model=Tool, operation_id="get_tool")
@@ -91,7 +93,16 @@ def create_tool(
except UniqueConstraintViolationError as e:
# Log or print the full exception here for debugging
print(f"Error occurred: {e}")
raise HTTPException(status_code=409, detail=str(e))
clean_error_message = f"Tool with name {request.name} already exists."
raise HTTPException(status_code=409, detail=clean_error_message)
except LettaToolCreateError as e:
# HTTP 400 == Bad Request
print(f"Error occurred during tool creation: {e}")
# print the full stack trace
import traceback
print(traceback.format_exc())
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# Catch other unexpected errors and raise an internal server error
print(f"Unexpected error occurred: {e}")