feat: add generate tool api (#3519)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from composio.client import ComposioClientError, HTTPError, NoItemsFound
|
||||
@@ -17,10 +18,14 @@ from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
@@ -686,3 +691,88 @@ async def generate_json_schema(
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
|
||||
|
||||
|
||||
class GenerateToolInput(BaseModel):
|
||||
tool_name: str = Field(..., description="Name of the tool to generate code for")
|
||||
prompt: str = Field(..., description="User prompt to generate code")
|
||||
handle: Optional[str] = Field(None, description="Handle of the tool to generate code for")
|
||||
starter_code: Optional[str] = Field(None, description="Python source code to parse for JSON schema")
|
||||
validation_errors: List[str] = Field(..., description="List of validation errors")
|
||||
|
||||
|
||||
class GenerateToolOutput(BaseModel):
|
||||
tool: Tool = Field(..., description="Generated tool")
|
||||
sample_args: Dict[str, Any] = Field(..., description="Sample arguments for the tool")
|
||||
response: str = Field(..., description="Response from the assistant")
|
||||
|
||||
|
||||
@router.post("/generate-tool", response_model=GenerateToolOutput, operation_id="generate_tool")
|
||||
async def generate_tool_from_prompt(
|
||||
request: GenerateToolInput = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Generate a tool from the given user prompt.
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
llm_config = await server.get_cached_llm_config_async(actor=actor, handle=request.handle or "anthropic/claude-3-5-sonnet-20240620")
|
||||
formatted_prompt = (
|
||||
f"Generate a python function named {request.tool_name} using the instructions below "
|
||||
+ (f"based on this starter code: \n\n```\n{request.starter_code}\n```\n\n" if request.starter_code else "\n")
|
||||
+ (f"Note the following validation errors: \n{' '.join(request.validation_errors)}\n\n" if request.validation_errors else "\n")
|
||||
+ f"Instructions: {request.prompt}"
|
||||
)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=llm_config.model_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
assert llm_client is not None
|
||||
|
||||
input_messages = [
|
||||
Message(role=MessageRole.system, content=[TextContent(text="Placeholder system message")]),
|
||||
Message(role=MessageRole.assistant, content=[TextContent(text="Placeholder assistant message")]),
|
||||
Message(role=MessageRole.user, content=[TextContent(text=formatted_prompt)]),
|
||||
]
|
||||
|
||||
tool = {
|
||||
"name": "generate_tool",
|
||||
"description": "This method generates the raw source code for a custom tool that can be attached to and agent for llm invocation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"raw_source_code": {"type": "string", "description": "The raw python source code of the custom tool."},
|
||||
"sample_args_json": {
|
||||
"type": "string",
|
||||
"description": "The JSON dict that contains sample args for a test run of the python function. Key is the name of the function parameter and value is an example argument that is passed in.",
|
||||
},
|
||||
"pip_requirements_json": {
|
||||
"type": "string",
|
||||
"description": "Optional JSON dict that contains pip packages to be installed if needed by the source code. Key is the name of the pip package and value is the version number.",
|
||||
},
|
||||
},
|
||||
"required": ["raw_source_code", "sample_args_json", "pip_requirements_json"],
|
||||
},
|
||||
}
|
||||
request_data = llm_client.build_request_data(
|
||||
input_messages,
|
||||
llm_config,
|
||||
tools=[tool],
|
||||
)
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
||||
return GenerateToolOutput(
|
||||
tool=Tool(
|
||||
name=request.tool_name,
|
||||
source_type="python",
|
||||
source_code=output["raw_source_code"],
|
||||
),
|
||||
sample_args=json.loads(output["sample_args_json"]),
|
||||
response=response.choices[0].message.content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate tool: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to generate tool: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user