diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 4b837b4a..073e1e15 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -189,7 +189,7 @@ async def upsert_base_tools( Upsert base tools """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.tool_manager.upsert_base_tools(actor=actor) + return await server.tool_manager.upsert_base_tools_async(actor=actor) @router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source") diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 88ca95a5..9e7bf42f 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -1,3 +1,4 @@ +import asyncio import importlib import warnings from typing import List, Optional @@ -368,3 +369,68 @@ class ToolManager: # TODO: Delete any base tools that are stale return tools + + @enforce_types + async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]: + """Add default tools in base.py and multi_agent.py""" + functions_to_schema = {} + module_names = ["base", "multi_agent", "voice", "builtin"] + + for module_name in module_names: + full_module_name = f"letta.functions.function_sets.{module_name}" + try: + module = importlib.import_module(full_module_name) + except Exception as e: + # Handle other general exceptions + raise e + + try: + # Load the function set + functions_to_schema.update(load_function_set(module)) + except ValueError as e: + err = f"Error loading function set '{module_name}': {e}" + warnings.warn(err) + + # create tool in db + tools = [] + for name, schema in functions_to_schema.items(): + if name in LETTA_TOOL_SET: + if name in BASE_TOOLS: + tool_type = ToolType.LETTA_CORE + tags = [tool_type.value] + elif name in BASE_MEMORY_TOOLS: + tool_type = ToolType.LETTA_MEMORY_CORE + tags = [tool_type.value] + elif name in MULTI_AGENT_TOOLS: + tool_type = ToolType.LETTA_MULTI_AGENT_CORE + tags = [tool_type.value] + elif name in BASE_SLEEPTIME_TOOLS: + tool_type = ToolType.LETTA_SLEEPTIME_CORE + tags = [tool_type.value] + elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS: + tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE + tags = [tool_type.value] + elif name in BUILTIN_TOOLS: + tool_type = ToolType.LETTA_BUILTIN + tags = [tool_type.value] + else: + raise ValueError( + f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}" + ) + + # create to tool + tools.append( + self.create_or_update_tool_async( + PydanticTool( + name=name, + tags=tags, + source_type="python", + tool_type=tool_type, + return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT, + ), + actor=actor, + ) + ) + + # TODO: Delete any base tools that are stale + return await asyncio.gather(*tools) diff --git a/tests/test_managers.py b/tests/test_managers.py index 99fda9fd..235a12dd 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2418,14 +2418,15 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e assert len(tools) == 0 -def test_upsert_base_tools(server: SyncServer, default_user): - tools = server.tool_manager.upsert_base_tools(actor=default_user) +@pytest.mark.asyncio +async def test_upsert_base_tools(server: SyncServer, default_user, event_loop): + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user) expected_tool_names = sorted(LETTA_TOOL_SET) assert sorted([t.name for t in tools]) == expected_tool_names # Call it again to make sure it doesn't create duplicates - tools = server.tool_manager.upsert_base_tools(actor=default_user) + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user) assert sorted([t.name for t in tools]) == expected_tool_names # Confirm that the return tools have no source_code, but a json_schema