fix: make upsert base tools async (#2255)
This commit is contained in:
@@ -189,7 +189,7 @@ async def upsert_base_tools(
|
||||
Upsert base tools
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return server.tool_manager.upsert_base_tools(actor=actor)
|
||||
return await server.tool_manager.upsert_base_tools_async(actor=actor)
|
||||
|
||||
|
||||
@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
@@ -368,3 +369,68 @@ class ToolManager:
|
||||
|
||||
# TODO: Delete any base tools that are stale
|
||||
return tools
|
||||
|
||||
@enforce_types
|
||||
async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""Add default tools in base.py and multi_agent.py"""
|
||||
functions_to_schema = {}
|
||||
module_names = ["base", "multi_agent", "voice", "builtin"]
|
||||
|
||||
for module_name in module_names:
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema.update(load_function_set(module))
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
warnings.warn(err)
|
||||
|
||||
# create tool in db
|
||||
tools = []
|
||||
for name, schema in functions_to_schema.items():
|
||||
if name in LETTA_TOOL_SET:
|
||||
if name in BASE_TOOLS:
|
||||
tool_type = ToolType.LETTA_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_MEMORY_TOOLS:
|
||||
tool_type = ToolType.LETTA_MEMORY_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in MULTI_AGENT_TOOLS:
|
||||
tool_type = ToolType.LETTA_MULTI_AGENT_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_SLEEPTIME_TOOLS:
|
||||
tool_type = ToolType.LETTA_SLEEPTIME_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
|
||||
tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BUILTIN_TOOLS:
|
||||
tool_type = ToolType.LETTA_BUILTIN
|
||||
tags = [tool_type.value]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}"
|
||||
)
|
||||
|
||||
# create to tool
|
||||
tools.append(
|
||||
self.create_or_update_tool_async(
|
||||
PydanticTool(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
tool_type=tool_type,
|
||||
return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Delete any base tools that are stale
|
||||
return await asyncio.gather(*tools)
|
||||
|
||||
@@ -2418,14 +2418,15 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
def test_upsert_base_tools(server: SyncServer, default_user):
|
||||
tools = server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_base_tools(server: SyncServer, default_user, event_loop):
|
||||
tools = await server.tool_manager.upsert_base_tools_async(actor=default_user)
|
||||
expected_tool_names = sorted(LETTA_TOOL_SET)
|
||||
|
||||
assert sorted([t.name for t in tools]) == expected_tool_names
|
||||
|
||||
# Call it again to make sure it doesn't create duplicates
|
||||
tools = server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
tools = await server.tool_manager.upsert_base_tools_async(actor=default_user)
|
||||
assert sorted([t.name for t in tools]) == expected_tool_names
|
||||
|
||||
# Confirm that the return tools have no source_code, but a json_schema
|
||||
|
||||
Reference in New Issue
Block a user