feat: Auto-refresh json_schema after tool update (#1958)
This commit is contained in:
@@ -48,8 +48,8 @@ swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provide
|
||||
swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
|
||||
|
||||
# create tools
|
||||
transfer_a = swarm.client.create_tool(transfer_agent_a, terminal=True)
|
||||
transfer_b = swarm.client.create_tool(transfer_agent_b, terminal=True)
|
||||
transfer_a = swarm.client.create_tool(transfer_agent_a)
|
||||
transfer_b = swarm.client.create_tool(transfer_agent_b)
|
||||
|
||||
# create agents
|
||||
if swarm.client.get_agent_id("agentb"):
|
||||
|
||||
@@ -205,6 +205,7 @@ class AbstractClient(object):
|
||||
self,
|
||||
id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
func: Optional[Callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
@@ -1302,6 +1303,7 @@ class RESTClient(AbstractClient):
|
||||
self,
|
||||
id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
func: Optional[Callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
@@ -1324,7 +1326,7 @@ class RESTClient(AbstractClient):
|
||||
|
||||
source_type = "python"
|
||||
|
||||
request = ToolUpdate(source_type=source_type, source_code=source_code, tags=tags, name=name)
|
||||
request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update tool: {response.text}")
|
||||
@@ -2233,7 +2235,6 @@ class LocalClient(AbstractClient):
|
||||
def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
|
||||
tool_create = ToolCreate.from_langchain(
|
||||
langchain_tool=langchain_tool,
|
||||
organization_id=self.org_id,
|
||||
additional_imports_module_attr_map=additional_imports_module_attr_map,
|
||||
)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
|
||||
@@ -2242,12 +2243,11 @@ class LocalClient(AbstractClient):
|
||||
tool_create = ToolCreate.from_crewai(
|
||||
crewai_tool=crewai_tool,
|
||||
additional_imports_module_attr_map=additional_imports_module_attr_map,
|
||||
organization_id=self.org_id,
|
||||
)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
|
||||
|
||||
def load_composio_tool(self, action: "ActionType") -> Tool:
|
||||
tool_create = ToolCreate.from_composio(action=action, organization_id=self.org_id)
|
||||
tool_create = ToolCreate.from_composio(action=action)
|
||||
return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user)
|
||||
|
||||
# TODO: Use the above function `add_tool` here as there is duplicate logic
|
||||
@@ -2257,7 +2257,6 @@ class LocalClient(AbstractClient):
|
||||
name: Optional[str] = None,
|
||||
update: Optional[bool] = True, # TODO: actually use this
|
||||
tags: Optional[List[str]] = None,
|
||||
terminal: Optional[bool] = False,
|
||||
) -> Tool:
|
||||
"""
|
||||
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
|
||||
@@ -2267,7 +2266,6 @@ class LocalClient(AbstractClient):
|
||||
name: (str): Name of the tool (must be unique per-user.)
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
update (bool, optional): Update the tool if it already exists. Defaults to True.
|
||||
terminal (bool, optional): Whether the tool is a terminal tool (no more agent steps). Defaults to False.
|
||||
|
||||
Returns:
|
||||
tool (Tool): The created tool.
|
||||
@@ -2287,7 +2285,6 @@ class LocalClient(AbstractClient):
|
||||
source_code=source_code,
|
||||
name=name,
|
||||
tags=tags,
|
||||
terminal=terminal,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
@@ -2296,6 +2293,7 @@ class LocalClient(AbstractClient):
|
||||
self,
|
||||
id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
func: Optional[callable] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Tool:
|
||||
@@ -2316,6 +2314,7 @@ class LocalClient(AbstractClient):
|
||||
"source_code": parse_source_code(func) if func else None,
|
||||
"tags": tags,
|
||||
"name": name,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
# Filter out any None values from the dictionary
|
||||
|
||||
@@ -7,10 +7,9 @@ from typing import Optional
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.schemas.tool import ToolCreate
|
||||
|
||||
|
||||
def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
|
||||
def derive_openai_json_schema(source_code: str, name: Optional[str]) -> dict:
|
||||
# auto-generate openai schema
|
||||
try:
|
||||
# Define a custom environment with necessary imports
|
||||
@@ -19,14 +18,14 @@ def derive_openai_json_schema(tool_create: ToolCreate) -> dict:
|
||||
}
|
||||
|
||||
env.update(globals())
|
||||
exec(tool_create.source_code, env)
|
||||
exec(source_code, env)
|
||||
|
||||
# get available functions
|
||||
functions = [f for f in env if callable(env[f])]
|
||||
|
||||
# TODO: not sure if this always works
|
||||
func = env[functions[-1]]
|
||||
json_schema = generate_schema(func, terminal=tool_create.terminal, name=tool_create.name)
|
||||
json_schema = generate_schema(func, name=name)
|
||||
return json_schema
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to execute source code: {e}")
|
||||
@@ -51,7 +50,7 @@ def load_function_set(module: ModuleType) -> dict:
|
||||
if attr_name in function_dict:
|
||||
raise ValueError(f"Found a duplicate of function name '{attr_name}'")
|
||||
|
||||
generated_schema = generate_schema(attr, terminal=False)
|
||||
generated_schema = generate_schema(attr)
|
||||
function_dict[attr_name] = {
|
||||
"module": inspect.getsource(module),
|
||||
"python_function": attr,
|
||||
|
||||
@@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model):
|
||||
}
|
||||
|
||||
|
||||
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None) -> dict:
|
||||
def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
|
||||
@@ -128,7 +128,7 @@ def generate_schema(function, terminal: Optional[bool], name: Optional[str] = No
|
||||
|
||||
# append the heartbeat
|
||||
# TODO: don't hard-code
|
||||
if function.__name__ not in ["send_message", "pause_heartbeats"] and not terminal:
|
||||
if function.__name__ not in ["send_message", "pause_heartbeats"]:
|
||||
schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
|
||||
@@ -38,8 +38,5 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
|
||||
)
|
||||
|
||||
# TODO: add terminal here eventually
|
||||
# This was an intentional decision by Sarah
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
|
||||
@@ -10,7 +10,6 @@ from letta.functions.helpers import (
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
|
||||
class BaseTool(LettaBase):
|
||||
@@ -69,10 +68,9 @@ class ToolCreate(LettaBase):
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
|
||||
|
||||
@classmethod
|
||||
def from_composio(cls, action: "ActionType", organization_id: str = OrganizationManager.DEFAULT_ORG_ID) -> "ToolCreate":
|
||||
def from_composio(cls, action: "ActionType") -> "ToolCreate":
|
||||
"""
|
||||
Class method to create an instance of Letta-compatible Composio Tool.
|
||||
Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
|
||||
@@ -114,7 +112,6 @@ class ToolCreate(LettaBase):
|
||||
cls,
|
||||
langchain_tool: "LangChainBaseTool",
|
||||
additional_imports_module_attr_map: dict[str, str] = None,
|
||||
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
|
||||
) -> "ToolCreate":
|
||||
"""
|
||||
Class method to create an instance of Tool from a Langchain tool (must be from langchain_community.tools).
|
||||
@@ -147,7 +144,6 @@ class ToolCreate(LettaBase):
|
||||
cls,
|
||||
crewai_tool: "CrewAIBaseTool",
|
||||
additional_imports_module_attr_map: dict[str, str] = None,
|
||||
organization_id: str = OrganizationManager.DEFAULT_ORG_ID,
|
||||
) -> "ToolCreate":
|
||||
"""
|
||||
Class method to create an instance of Tool from a crewAI BaseTool object.
|
||||
@@ -212,5 +208,7 @@ class ToolUpdate(LettaBase):
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
module: Optional[str] = Field(None, description="The source code of the function.")
|
||||
source_code: Optional[str] = Field(None, description="The source code of the function.")
|
||||
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.server.rest_api.interface import QueuingInterface
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/agents", tags=["agents"], response_model=List[AgentState])
|
||||
def get_all_agents():
|
||||
"""
|
||||
Get a list of all agents in the database
|
||||
"""
|
||||
interface.clear()
|
||||
return server.list_agents()
|
||||
|
||||
return router
|
||||
@@ -1,82 +0,0 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.tool import Tool as ToolModel # TODO: modify
|
||||
from letta.server.rest_api.interface import QueuingInterface
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
tools: List[ToolModel] = Field(..., description="List of tools (functions).")
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
json_schema: dict = Field(..., description="JSON schema of the tool.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
|
||||
|
||||
class CreateToolResponse(BaseModel):
|
||||
tool: ToolModel = Field(..., description="Information about the newly created tool.")
|
||||
|
||||
|
||||
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
|
||||
@router.delete("/tools/{tool_name}", tags=["tools"])
|
||||
async def delete_tool(
|
||||
tool_name: str,
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
server.ms.delete_tool(name=tool_name, user_id=None)
|
||||
|
||||
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
|
||||
async def get_tool(tool_name: str):
|
||||
"""
|
||||
Get a tool by name
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tool = server.ms.get_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
|
||||
tool = server.ms.get_tool(tool_name=tool_name, user_id=None)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
|
||||
return tool
|
||||
|
||||
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
|
||||
async def list_all_tools():
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
# tools = server.ms.list_tools(user_id=user_id) TODO: add back when user-specific
|
||||
tools = server.ms.list_tools(user_id=None)
|
||||
return ListToolsResponse(tools=tools)
|
||||
|
||||
@router.post("/tools", tags=["tools"], response_model=ToolModel)
|
||||
async def create_tool(
|
||||
request: CreateToolRequest = Body(...),
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
return server.create_tool(
|
||||
json_schema=request.json_schema, source_code=request.source_code, source_type=request.source_type, tags=request.tags
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
|
||||
|
||||
return router
|
||||
@@ -1,98 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Query
|
||||
|
||||
from letta.schemas.api_key import APIKey, APIKeyCreate
|
||||
from letta.schemas.user import User, UserCreate
|
||||
from letta.server.rest_api.interface import QueuingInterface
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/users", tags=["admin"], response_model=List[User])
|
||||
def get_all_users(cursor: Optional[str] = Query(None), limit: Optional[int] = Query(50)):
|
||||
"""
|
||||
Get a list of all users in the database
|
||||
"""
|
||||
try:
|
||||
# TODO: make this call a server function
|
||||
_, users = server.ms.get_all_users(cursor=cursor, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return users
|
||||
|
||||
@router.post("/users", tags=["admin"], response_model=User)
|
||||
def create_user(request: UserCreate = Body(...)):
|
||||
"""
|
||||
Create a new user in the database
|
||||
"""
|
||||
try:
|
||||
user = server.user_manager.create_user(request)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return user
|
||||
|
||||
@router.delete("/users", tags=["admin"], response_model=User)
|
||||
def delete_user(
|
||||
user_id: str = Query(..., description="The user_id key to be deleted."),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
user = server.ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
server.ms.delete_user(user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return user
|
||||
|
||||
@router.post("/users/keys", tags=["admin"], response_model=APIKey)
|
||||
def create_new_api_key(request: APIKeyCreate = Body(...)):
|
||||
"""
|
||||
Create a new API key for a user
|
||||
"""
|
||||
try:
|
||||
api_key = server.create_api_key(request)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return api_key
|
||||
|
||||
@router.get("/users/keys", tags=["admin"], response_model=List[APIKey])
|
||||
def get_api_keys(
|
||||
user_id: str = Query(..., description="The unique identifier of the user."),
|
||||
):
|
||||
"""
|
||||
Get a list of all API keys for a user
|
||||
"""
|
||||
try:
|
||||
if server.ms.get_user(user_id=user_id) is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return api_keys
|
||||
|
||||
@router.delete("/users/keys", tags=["admin"], response_model=APIKey)
|
||||
def delete_api_key(
|
||||
api_key: str = Query(..., description="The API key to be deleted."),
|
||||
):
|
||||
try:
|
||||
return server.delete_api_key(api_key)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
return router
|
||||
@@ -820,7 +820,7 @@ class SyncServer(Server):
|
||||
continue
|
||||
source_code = parse_source_code(func)
|
||||
# memory functions are not terminal
|
||||
json_schema = generate_schema(func, terminal=False, name=func_name)
|
||||
json_schema = generate_schema(func, name=func_name)
|
||||
source_type = "python"
|
||||
tags = ["memory", "memgpt-base"]
|
||||
tool = self.tool_manager.create_or_update_tool(
|
||||
|
||||
@@ -28,7 +28,9 @@ class ToolManager:
|
||||
def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Derive json_schema
|
||||
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create)
|
||||
derived_json_schema = tool_create.json_schema or derive_openai_json_schema(
|
||||
source_code=tool_create.source_code, name=tool_create.name
|
||||
)
|
||||
derived_name = tool_create.name or derived_json_schema["name"]
|
||||
|
||||
try:
|
||||
@@ -36,7 +38,7 @@ class ToolManager:
|
||||
# This is important, because even if it's a different user, adding the same tool to the org should not happen
|
||||
tool = self.get_tool_by_name(tool_name=derived_name, actor=actor)
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = tool_create.model_dump(exclude={"module", "terminal"}, exclude_unset=True)
|
||||
update_data = tool_create.model_dump(exclude={"module"}, exclude_unset=True)
|
||||
# Remove redundant update fields
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value}
|
||||
|
||||
@@ -59,8 +61,7 @@ class ToolManager:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Create the tool
|
||||
with self.session_maker() as session:
|
||||
# Include all fields except `terminal` (which is not part of ToolModel) at the moment
|
||||
create_data = tool_create.model_dump(exclude={"terminal"})
|
||||
create_data = tool_create.model_dump()
|
||||
tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel
|
||||
tool.create(session, actor=actor)
|
||||
|
||||
@@ -106,6 +107,22 @@ class ToolManager:
|
||||
for key, value in update_data.items():
|
||||
setattr(tool, key, value)
|
||||
|
||||
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the name and schema
|
||||
# CAUTION: This will override any name/schema values the user passed in
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
|
||||
# Decide whether or not to reset name
|
||||
# If name was not explicitly passed in as part of the update, then we auto-generate a new name based on source code
|
||||
name = None
|
||||
if "name" in update_data.keys():
|
||||
name = update_data["name"]
|
||||
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code, name=name)
|
||||
|
||||
# The name will either be set (if explicit) or autogenerated from the source code
|
||||
tool.name = new_schema["name"]
|
||||
tool.json_schema = new_schema
|
||||
|
||||
# Save the updated tool to the database
|
||||
tool.update(db_session=session, actor=actor)
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def tool_fixture(server: SyncServer):
|
||||
user = server.user_manager.create_default_user()
|
||||
other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id))
|
||||
tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(tool_create)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool_create.json_schema = derived_json_schema
|
||||
tool_create.name = derived_name
|
||||
@@ -182,7 +182,7 @@ def test_create_tool(server: SyncServer, tool_fixture):
|
||||
assert tool.tags == tool_create.tags
|
||||
assert tool.source_code == tool_create.source_code
|
||||
assert tool.source_type == tool_create.source_type
|
||||
assert tool.json_schema == derive_openai_json_schema(tool_create)
|
||||
assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, tool_fixture):
|
||||
@@ -220,7 +220,6 @@ def test_get_tool_with_actor(server: SyncServer, tool_fixture):
|
||||
|
||||
def test_list_tools(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
tool_fixture["organization"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# List tools (should include the one created by the fixture)
|
||||
@@ -249,6 +248,85 @@ def test_update_tool_by_id(server: SyncServer, tool_fixture):
|
||||
assert updated_tool.description == updated_description
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture):
|
||||
def counter_tool(counter: int):
|
||||
"""
|
||||
Args:
|
||||
counter (int): The counter to count to.
|
||||
|
||||
Returns:
|
||||
bool: If it successfully counted to the counter.
|
||||
"""
|
||||
for c in range(counter):
|
||||
print(c)
|
||||
|
||||
return True
|
||||
|
||||
# Test begins
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
og_json_schema = tool_fixture["tool_create"].json_schema
|
||||
|
||||
source_code = parse_source_code(counter_tool)
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's source_code
|
||||
tool_update = ToolUpdate(source_code=source_code)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
|
||||
# Assertions to check if the update was successful, and json_schema is updated as well
|
||||
assert updated_tool.source_code == source_code
|
||||
assert updated_tool.json_schema != og_json_schema
|
||||
|
||||
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name)
|
||||
assert updated_tool.json_schema == new_schema
|
||||
assert updated_tool.name == new_schema["name"]
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture):
|
||||
def counter_tool(counter: int):
|
||||
"""
|
||||
Args:
|
||||
counter (int): The counter to count to.
|
||||
|
||||
Returns:
|
||||
bool: If it successfully counted to the counter.
|
||||
"""
|
||||
for c in range(counter):
|
||||
print(c)
|
||||
|
||||
return True
|
||||
|
||||
# Test begins
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
og_json_schema = tool_fixture["tool_create"].json_schema
|
||||
|
||||
source_code = parse_source_code(counter_tool)
|
||||
name = "test_function_name_explicit"
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's source_code
|
||||
tool_update = ToolUpdate(name=name, source_code=source_code)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
|
||||
# Assertions to check if the update was successful, and json_schema is updated as well
|
||||
assert updated_tool.source_code == source_code
|
||||
assert updated_tool.json_schema != og_json_schema
|
||||
|
||||
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name)
|
||||
assert updated_tool.json_schema == new_schema
|
||||
assert updated_tool.name == name
|
||||
|
||||
|
||||
def test_update_tool_multi_user(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
@@ -272,7 +350,6 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture):
|
||||
|
||||
def test_delete_tool_by_id(server: SyncServer, tool_fixture):
|
||||
tool = tool_fixture["tool"]
|
||||
tool_fixture["organization"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# Delete the tool using the manager method
|
||||
|
||||
@@ -42,21 +42,21 @@ def test_schema_generator():
|
||||
"required": ["message"],
|
||||
},
|
||||
}
|
||||
generated_schema = generate_schema(send_message, terminal=True)
|
||||
generated_schema = generate_schema(send_message)
|
||||
print(f"\n\nreference_schema={correct_schema}")
|
||||
print(f"\n\ngenerated_schema={generated_schema}")
|
||||
assert correct_schema == generated_schema
|
||||
|
||||
# Check that missing types results in an error
|
||||
try:
|
||||
_ = generate_schema(send_message_missing_types, terminal=True)
|
||||
_ = generate_schema(send_message_missing_types)
|
||||
assert False
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check that missing docstring results in an error
|
||||
try:
|
||||
_ = generate_schema(send_message_missing_docstring, terminal=True)
|
||||
_ = generate_schema(send_message_missing_docstring)
|
||||
assert False
|
||||
except:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user