feat: Auto-refresh json_schema after tool update (#1958)

This commit is contained in:
Matthew Zhou
2024-10-30 15:05:08 -07:00
committed by GitHub
parent fc03cf89d7
commit 0784bdc854
14 changed files with 124 additions and 238 deletions

View File

@@ -48,8 +48,8 @@ swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provide
swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
# create tools
transfer_a = swarm.client.create_tool(transfer_agent_a, terminal=True)
transfer_b = swarm.client.create_tool(transfer_agent_b, terminal=True)
transfer_a = swarm.client.create_tool(transfer_agent_a)
transfer_b = swarm.client.create_tool(transfer_agent_b)
# create agents
if swarm.client.get_agent_id("agentb"):

View File

@@ -205,6 +205,7 @@ class AbstractClient(object):
self,
id: str,
name: Optional[str] = None,
description: Optional[str] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
) -> Tool:
@@ -1302,6 +1303,7 @@ class RESTClient(AbstractClient):
self,
id: str,
name: Optional[str] = None,
description: Optional[str] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
) -> Tool:
@@ -1324,7 +1326,7 @@ class RESTClient(AbstractClient):
source_type = "python"
request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name)
request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update tool: {response.text}")
@@ -2233,7 +2235,6 @@ class LocalClient(AbstractClient):
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_langchain(
langchain_tool=langchain_tool,
organization_id=self.org_id,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
@@ -2242,12 +2243,11 @@ class LocalClient(AbstractClient):
tool_create = ToolCreate.from_crewai(
crewai_tool=crewai_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
organization_id=self.org_id,
)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
def load_composio_tool(self, action: "ActionType") -> Tool:
tool_create = ToolCreate.from_composio(action=action, organization_id=self.org_id)
tool_create = ToolCreate.from_composio(action=action)
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
# TODO: Use the above function `add_tool` here as there is duplicate logic
@@ -2257,7 +2257,6 @@ class LocalClient(AbstractClient):
name: Optional[str] = None,
update: Optional[bool] = True, # TODO: actually use this
tags: Optional[List[str]] = None,
terminal: Optional[bool] = False,
) -> Tool:
"""
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -2267,7 +2266,6 @@ class LocalClient(AbstractClient):
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
update (bool, optional): Update the tool if it already exists. Defaults to True.
terminal (bool, optional): Whether the tool is a terminal tool (no more agent steps). Defaults to False.
Returns:
tool (Tool): The created tool.
@@ -2287,7 +2285,6 @@ class LocalClient(AbstractClient):
source_code=source_code,
name=name,
tags=tags,
terminal=terminal,
),
actor=self.user,
)
@@ -2296,6 +2293,7 @@ class LocalClient(AbstractClient):
self,
id: str,
name: Optional[str] = None,
description: Optional[str] = None,
func: Optional[callable] = None,
tags: Optional[List[str]] = None,
) -> Tool:
@@ -2316,6 +2314,7 @@ class LocalClient(AbstractClient):
"source_code": parse_source_code(func) if func else None,
"tags": tags,
"name": name,
"description": description,
}
# Filter out any None values from the dictionary

View File

@@ -7,10 +7,9 @@ from typing import Optional
from letta.constants import CLI_WARNING_PREFIX
from letta.functions.schema_generator import generate_schema
from letta.schemas.tool import ToolCreate
def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
def derive_openai_json_schema(source_code: str, name: Optional[str]) -> dict:
# auto-generate openai schema
try:
# Define a custom environment with necessary imports
@@ -19,14 +18,14 @@ def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
}
env.update(globals())
exec(tool_create.source_code, env)
exec(source_code, env)
# get available functions
functions = [f for f in env if callable(env[f])]
# TODO: not sure if this always works
func = env[functions[-1]]
json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name)
json_schema = generate_schema(func, name=name)
return json_schema
except Exception as e:
raise RuntimeError(f"Failed to execute source code: {e}")
@@ -51,7 +50,7 @@ def load_function_set(module: ModuleType) -> dict:
if attr_name in function_dict:
raise ValueError(f"Found a duplicate of function name '{attr_name}'")
generated_schema = generate_schema(attr, terminal=False)
generated_schema = generate_schema(attr)
function_dict[attr_name] = {
"module": inspect.getsource(module),
"python_function": attr,

View File

@@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model):
}
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict:
def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
# Get the signature of the function
sig = inspect.signature(function)
@@ -128,7 +128,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No
# append the heartbeat
# TODO: don't hard-code
if function.__name__ not in ["send_message", "pause_heartbeats"] and not terminal:
if function.__name__ not in ["send_message", "pause_heartbeats"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",

View File

@@ -38,8 +38,5 @@ class Tool(SqlalchemyBase, OrganizationMixin):
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
)
# TODO: add terminal here eventually
# This was an intentional decision by Sarah
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")

View File

@@ -10,7 +10,6 @@ from letta.functions.helpers import (
from letta.functions.schema_generator import generate_schema_from_args_schema
from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall
from letta.services.organization_manager import OrganizationManager
class BaseTool(LettaBase):
@@ -69,10 +68,9 @@ class ToolCreate(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
@classmethod
def from_composio(cls, action: "ActionType", organization_id: str = OrganizationManager.DEFAULT_ORG_ID) -> "ToolCreate":
def from_composio(cls, action: "ActionType") -> "ToolCreate":
"""
Class method to create an instance of Letta-compatible Composio Tool.
Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
@@ -114,7 +112,6 @@ class ToolCreate(LettaBase):
cls,
langchain_tool: "LangChainBaseTool",
additional_imports_module_attr_map: dict[str, str] = None,
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
) -> "ToolCreate":
"""
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
@@ -147,7 +144,6 @@ class ToolCreate(LettaBase):
cls,
crewai_tool: "CrewAIBaseTool",
additional_imports_module_attr_map: dict[str, str] = None,
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
) -> "ToolCreate":
"""
Class method to create an instance of Tool from a crewAI BaseTool object.
@@ -212,5 +208,7 @@ class ToolUpdate(LettaBase):
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
module: Optional[str] = Field(None, description="The source code of the function.")
source_code: Optional[str] = Field(None, description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)

View File

@@ -1,21 +0,0 @@
from typing import List
from fastapi import APIRouter
from letta.schemas.agent import AgentState
from letta.server.rest_api.interface import QueuingInterface
from letta.server.server import SyncServer
router = APIRouter()
def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents", tags=["agents"], response_model=List[AgentState])
def get_all_agents():
"""
Get a list of all agents in the database
"""
interface.clear()
return server.list_agents()
return router

View File

@@ -1,82 +0,0 @@
from typing import List, Literal, Optional
from fastapi import APIRouter, Body, HTTPException
from pydantic import BaseModel, Field
from letta.schemas.tool import Tool as ToolModel # TODO: modify
from letta.server.rest_api.interface import QueuingInterface
from letta.server.server import SyncServer
router = APIRouter()
class ListToolsResponse(BaseModel):
tools: List[ToolModel] = Field(..., description="List of tools (functions).")
class CreateToolRequest(BaseModel):
json_schema: dict = Field(..., description="JSON schema of the tool.")
source_code: str = Field(..., description="The source code of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
class CreateToolResponse(BaseModel):
tool: ToolModel = Field(..., description="Information about the newly created tool.")
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(
tool_name: str,
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
):
"""
Delete a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name, user_id=None)
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(tool_name: str):
"""
Get a tool by name
"""
# Clear the interface
interface.clear()
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
tool = server.ms.get_tool(tool_name=tool_name, user_id=None)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools():
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
tools = server.ms.list_tools(user_id=None)
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: CreateToolRequest = Body(...),
):
"""
Create a new tool
"""
try:
return server.create_tool(
json_schema=request.json_schema, source_code=request.source_code, source_type=request.source_type, tags=request.tags
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
return router

View File

@@ -1,98 +0,0 @@
from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Query
from letta.schemas.api_key import APIKey, APIKeyCreate
from letta.schemas.user import User, UserCreate
from letta.server.rest_api.interface import QueuingInterface
from letta.server.server import SyncServer
router = APIRouter()
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=List[User])
def get_all_users(cursor: Optional[str] = Query(None), limit: Optional[int] = Query(50)):
"""
Get a list of all users in the database
"""
try:
# TODO: make this call a server function
_, users = server.ms.get_all_users(cursor=cursor, limit=limit)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return users
@router.post("/users", tags=["admin"], response_model=User)
def create_user(request: UserCreate = Body(...)):
"""
Create a new user in the database
"""
try:
user = server.user_manager.create_user(request)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return user
@router.delete("/users", tags=["admin"], response_model=User)
def delete_user(
user_id: str = Query(..., description="The user_id key to be deleted."),
):
# TODO make a soft deletion, instead of a hard deletion
try:
user = server.ms.get_user(user_id=user_id)
if user is None:
raise HTTPException(status_code=404, detail=f"User does not exist")
server.ms.delete_user(user_id=user_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return user
@router.post("/users/keys", tags=["admin"], response_model=APIKey)
def create_new_api_key(request: APIKeyCreate = Body(...)):
"""
Create a new API key for a user
"""
try:
api_key = server.create_api_key(request)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return api_key
@router.get("/users/keys", tags=["admin"], response_model=List[APIKey])
def get_api_keys(
user_id: str = Query(..., description="The unique identifier of the user."),
):
"""
Get a list of all API keys for a user
"""
try:
if server.ms.get_user(user_id=user_id) is None:
raise HTTPException(status_code=404, detail=f"User does not exist")
api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return api_keys
@router.delete("/users/keys", tags=["admin"], response_model=APIKey)
def delete_api_key(
api_key: str = Query(..., description="The API key to be deleted."),
):
try:
return server.delete_api_key(api_key)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router

View File

@@ -820,7 +820,7 @@ class SyncServer(Server):
continue
source_code = parse_source_code(func)
# memory functions are not terminal
json_schema = generate_schema(func, terminal=False, name=func_name)
json_schema = generate_schema(func, name=func_name)
source_type = "python"
tags = ["memory", "memgpt-base"]
tool = self.tool_manager.create_or_update_tool(

View File

@@ -28,7 +28,9 @@ class ToolManager:
def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Derive json_schema
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create)
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(
source_code=tool_create.source_code, name=tool_create.name
)
derived_name = tool_create.name or derived_json_schema["name"]
try:
@@ -36,7 +38,7 @@ class ToolManager:
# This is important, because even if it's a different user, adding the same tool to the org should not happen
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
# Put to dict and remove fields that should not be reset
update_data = tool_create.model_dump(exclude={"module", "terminal"}, exclude_unset=True)
update_data = tool_create.model_dump(exclude={"module"}, exclude_unset=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
@@ -59,8 +61,7 @@ class ToolManager:
"""Create a new tool based on the ToolCreate schema."""
# Create the tool
with self.session_maker() as session:
# Include all fields except `terminal` (which is not part of ToolModel) at the moment
create_data = tool_create.model_dump(exclude={"terminal"})
create_data = tool_create.model_dump()
tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel
tool.create(session, actor=actor)
@@ -106,6 +107,22 @@ class ToolManager:
for key, value in update_data.items():
setattr(tool, key, value)
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the name and schema
# CAUTION: This will override any name/schema values the user passed in
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
pydantic_tool = tool.to_pydantic()
# Decide whether or not to reset name
# If name was not explicitly passed in as part of the update, then we auto-generate a new name based on source code
name = None
if "name" in update_data.keys():
name = update_data["name"]
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code, name=name)
# The name will either be set (if explicit) or autogenerated from the source code
tool.name = new_schema["name"]
tool.json_schema = new_schema
# Save the updated tool to the database
tool.update(db_session=session, actor=actor)

View File

@@ -49,7 +49,7 @@ def tool_fixture(server: SyncServer):
user = server.user_manager.create_default_user()
other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id))
tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(tool_create)
derived_json_schema = derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
derived_name = derived_json_schema["name"]
tool_create.json_schema = derived_json_schema
tool_create.name = derived_name
@@ -182,7 +182,7 @@ def test_create_tool(server: SyncServer, tool_fixture):
assert tool.tags == tool_create.tags
assert tool.source_code == tool_create.source_code
assert tool.source_type == tool_create.source_type
assert tool.json_schema == derive_openai_json_schema(tool_create)
assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
def test_get_tool_by_id(server: SyncServer, tool_fixture):
@@ -220,7 +220,6 @@ def test_get_tool_with_actor(server: SyncServer, tool_fixture):
def test_list_tools(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
tool_fixture["organization"]
user = tool_fixture["user"]
# List tools (should include the one created by the fixture)
@@ -249,6 +248,85 @@ def test_update_tool_by_id(server: SyncServer, tool_fixture):
assert updated_tool.description == updated_description
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture):
def counter_tool(counter: int):
"""
Args:
counter (int): The counter to count to.
Returns:
bool: If it successfully counted to the counter.
"""
for c in range(counter):
print(c)
return True
# Test begins
tool = tool_fixture["tool"]
user = tool_fixture["user"]
og_json_schema = tool_fixture["tool_create"].json_schema
source_code = parse_source_code(counter_tool)
# Create a ToolUpdate object to modify the tool's source_code
tool_update = ToolUpdate(source_code=source_code)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
# Assertions to check if the update was successful, and json_schema is updated as well
assert updated_tool.source_code == source_code
assert updated_tool.json_schema != og_json_schema
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name)
assert updated_tool.json_schema == new_schema
assert updated_tool.name == new_schema["name"]
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture):
def counter_tool(counter: int):
"""
Args:
counter (int): The counter to count to.
Returns:
bool: If it successfully counted to the counter.
"""
for c in range(counter):
print(c)
return True
# Test begins
tool = tool_fixture["tool"]
user = tool_fixture["user"]
og_json_schema = tool_fixture["tool_create"].json_schema
source_code = parse_source_code(counter_tool)
name = "test_function_name_explicit"
# Create a ToolUpdate object to modify the tool's source_code
tool_update = ToolUpdate(name=name, source_code=source_code)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
# Assertions to check if the update was successful, and json_schema is updated as well
assert updated_tool.source_code == source_code
assert updated_tool.json_schema != og_json_schema
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name)
assert updated_tool.json_schema == new_schema
assert updated_tool.name == name
def test_update_tool_multi_user(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
@@ -272,7 +350,6 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture):
def test_delete_tool_by_id(server: SyncServer, tool_fixture):
tool = tool_fixture["tool"]
tool_fixture["organization"]
user = tool_fixture["user"]
# Delete the tool using the manager method

View File

@@ -42,21 +42,21 @@ def test_schema_generator():
"required": ["message"],
},
}
generated_schema = generate_schema(send_message, terminal=True)
generated_schema = generate_schema(send_message)
print(f"\n\nreference_schema={correct_schema}")
print(f"\n\ngenerated_schema={generated_schema}")
assert correct_schema == generated_schema
# Check that missing types results in an error
try:
_ = generate_schema(send_message_missing_types, terminal=True)
_ = generate_schema(send_message_missing_types)
assert False
except:
pass
# Check that missing docstring results in an error
try:
_ = generate_schema(send_message_missing_docstring, terminal=True)
_ = generate_schema(send_message_missing_docstring)
assert False
except:
pass