From 91432088f91521719d7dfd9a3bb08501666daf17 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 5 Jun 2025 14:01:29 -0700 Subject: [PATCH] feat: Add auto attach and detach of file tools (#2655) --- letta/constants.py | 2 +- letta/server/rest_api/routers/v1/agents.py | 7 ++ letta/services/agent_manager.py | 60 ++++++++++ letta/services/tool_manager.py | 95 ++++++++-------- tests/test_client.py | 50 --------- tests/test_managers.py | 39 +++++++ tests/test_sources.py | 124 +++++++++++---------- 7 files changed, 219 insertions(+), 158 deletions(-) diff --git a/letta/constants.py b/letta/constants.py index 3d2ea35a..3fd2a778 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -122,7 +122,7 @@ MEMORY_TOOLS_LINE_NUMBER_PREFIX_REGEX = re.compile( BUILTIN_TOOLS = ["run_code", "web_search"] # Built in tools -FILES_TOOLS = ["web_search", "run_code", "open_file", "close_file", "grep", "search_files"] +FILES_TOOLS = ["open_file", "close_file", "grep", "search_files"] # Set of all built-in Letta tools LETTA_TOOL_SET = set( diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index f416749f..b84f7dd4 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -311,6 +311,9 @@ async def attach_source( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) agent_state = await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=source_id, actor=actor) + # Check if the agent is missing any files tools + agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=actor) + files = await server.source_manager.list_files(source_id, actor, include_content=True) texts = [] file_ids = [] @@ -345,6 +348,10 @@ async def detach_source( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) agent_state = await server.agent_manager.detach_source_async(agent_id=agent_id, 
source_id=source_id, actor=actor) + + if not agent_state.sources: + agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=actor) + files = await server.source_manager.list_files(source_id, actor) file_ids = [f.id for f in files] await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 4cb0af6a..6816e178 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -16,6 +16,7 @@ from letta.constants import ( BASE_VOICE_SLEEPTIME_CHAT_TOOLS, BASE_VOICE_SLEEPTIME_TOOLS, DATA_SOURCE_ATTACH_ALERT, + FILES_TOOLS, MULTI_AGENT_TOOLS, ) from letta.helpers.datetime_helpers import get_utc_time @@ -2382,6 +2383,65 @@ class AgentManager: await agent.update_async(session, actor=actor) return await agent.to_pydantic_async() + @trace_method + @enforce_types + async def attach_missing_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: + """ + Attaches missing core file tools to an agent. + + Args: + agent_state: State of the agent to attach the tools to. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or tool is not found. + + Returns: + PydanticAgentState: The updated agent state. + """ + # Check if the agent is missing any files tools + core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + missing_tool_names = set(FILES_TOOLS).difference(core_tool_names) + + for tool_name in missing_tool_names: + tool_id = self.tool_manager.get_tool_id_by_name(tool_name=tool_name, actor=actor) + + # TODO: This is hacky and deserves a rethink - how do we keep all the base tools available in every org always? 
+ if not tool_id: + await self.tool_manager.upsert_base_tools_async(actor=actor, allowed_types={ToolType.LETTA_FILES_CORE}) + + # TODO: Inefficient - I think this re-retrieves the agent_state? NOTE(review): tool_id is not re-fetched after the upsert above, so a falsy tool_id appears to reach attach_tool_async - confirm and re-query by name. + agent_state = await self.attach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor) + + return agent_state + + @trace_method + @enforce_types + async def detach_all_files_tools_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: + """ + Detach all core file tools from an agent. + + Args: + agent_state: State of the agent to detach the tools from. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or tool is not found. + + Returns: + PydanticAgentState: The updated agent state. + """ + # Collect the names of the files tools currently attached to the agent + core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + + for tool_name in core_tool_names: + tool_id = self.tool_manager.get_tool_id_by_name(tool_name=tool_name, actor=actor) + + # TODO: Inefficient - I think this re-retrieves the agent_state? 
+ agent_state = await self.detach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor) + + return agent_state + @trace_method @enforce_types def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index cb0bc7dc..491e4bca 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -1,7 +1,7 @@ import asyncio import importlib import warnings -from typing import List, Optional, Union +from typing import List, Optional, Set, Union from letta.constants import ( BASE_FUNCTION_RETURN_CHAR_LIMIT, @@ -434,66 +434,59 @@ class ToolManager: @enforce_types @trace_method - async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]: - """Add default tools in base.py and multi_agent.py""" + async def upsert_base_tools_async( + self, + actor: PydanticUser, + allowed_types: Optional[Set[ToolType]] = None, + ) -> List[PydanticTool]: + """Add default tools defined in the various function_sets modules, optionally filtered by ToolType.""" + functions_to_schema = {} for module_name in LETTA_TOOL_MODULE_NAMES: try: module = importlib.import_module(module_name) - except Exception as e: - # Handle other general exceptions - raise e - - try: - # Load the function set functions_to_schema.update(load_function_set(module)) except ValueError as e: - err = f"Error loading function set '{module_name}': {e}" - warnings.warn(err) + warnings.warn(f"Error loading function set '{module_name}': {e}") + except Exception as e: + raise e - # create tool in db tools = [] for name, schema in functions_to_schema.items(): - if name in LETTA_TOOL_SET: - if name in BASE_TOOLS: - tool_type = ToolType.LETTA_CORE - tags = [tool_type.value] - elif name in BASE_MEMORY_TOOLS: - tool_type = ToolType.LETTA_MEMORY_CORE - tags = [tool_type.value] - elif name in MULTI_AGENT_TOOLS: - tool_type = ToolType.LETTA_MULTI_AGENT_CORE - tags = 
[tool_type.value] - elif name in BASE_SLEEPTIME_TOOLS: - tool_type = ToolType.LETTA_SLEEPTIME_CORE - tags = [tool_type.value] - elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS: - tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE - tags = [tool_type.value] - elif name in BUILTIN_TOOLS: - tool_type = ToolType.LETTA_BUILTIN - tags = [tool_type.value] - elif name in FILES_TOOLS: - tool_type = ToolType.LETTA_FILES_CORE - tags = [tool_type.value] - else: - raise ValueError( - f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}" - ) + if name not in LETTA_TOOL_SET: + continue - # create to tool - tools.append( - self.create_or_update_tool_async( - PydanticTool( - name=name, - tags=tags, - source_type="python", - tool_type=tool_type, - return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT, - ), - actor=actor, - ) + if name in BASE_TOOLS: + tool_type = ToolType.LETTA_CORE + elif name in BASE_MEMORY_TOOLS: + tool_type = ToolType.LETTA_MEMORY_CORE + elif name in BASE_SLEEPTIME_TOOLS: + tool_type = ToolType.LETTA_SLEEPTIME_CORE + elif name in MULTI_AGENT_TOOLS: + tool_type = ToolType.LETTA_MULTI_AGENT_CORE + elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS: + tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE + elif name in BUILTIN_TOOLS: + tool_type = ToolType.LETTA_BUILTIN + elif name in FILES_TOOLS: + tool_type = ToolType.LETTA_FILES_CORE + else: + raise ValueError(f"Tool name {name} is not recognized in any known base tool set.") + + if allowed_types is not None and tool_type not in allowed_types: + continue + + tools.append( + self.create_or_update_tool_async( + PydanticTool( + name=name, + tags=[tool_type.value], + source_type="python", + tool_type=tool_type, + return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT, + ), + actor=actor, ) + ) - # TODO: Delete any base tools 
that are stale return await asyncio.gather(*tools) diff --git a/tests/test_client.py b/tests/test_client.py index e2e691eb..c957a5b5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -347,48 +347,6 @@ def test_update_agent_memory_limit(client: Letta): # -------------------------------------------------------------------------------------------------------------------- # Agent Tools # -------------------------------------------------------------------------------------------------------------------- -def test_function_return_limit(client: Letta): - """Test to see if the function return limit works""" - - def big_return(): - """ - Always call this tool. - - Returns: - important_data (str): Important data - """ - return "x" * 100000 - - padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50 - tool = client.tools.upsert_from_function(func=big_return, return_char_limit=1000) - agent = client.agents.create( - model="letta/letta-free", - embedding="letta/letta-free", - tool_ids=[tool.id], - ) - # get function response - response = client.agents.messages.create( - agent_id=agent.id, messages=[MessageCreate(role="user", content="call the big_return function")] - ) - print(response.messages) - - response_message = None - for message in response.messages: - if message.message_type == "tool_return_message": - response_message = message - break - - assert response_message, "ToolReturnMessage message not found in response" - res = response_message.tool_return - assert "function output was truncated " in res - - # TODO: Re-enable later - # res_json = json.loads(res) - # assert ( - # len(res_json["message"]) <= 1000 + padding - # ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}" - - client.agents.delete(agent_id=agent.id) def test_function_always_error(client: Letta): @@ -495,14 +453,6 @@ def test_messages(client: Letta, agent: AgentState): assert 
len(messages_response) > 0, "Retrieving messages failed" -def test_send_system_message(client: Letta, agent: AgentState): - """Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)""" - send_system_message_response = client.agents.messages.create( - agent_id=agent.id, messages=[MessageCreate(role="system", content="Event occurred: The user just logged off.")] - ) - assert send_system_message_response, "Sending message failed" - - # TODO: Add back when new agent loop hits # @pytest.mark.asyncio # async def test_send_message_parallel(client: Letta, agent: AgentState, request): diff --git a/tests/test_managers.py b/tests/test_managers.py index f861be83..0cbcfcc6 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2550,6 +2550,45 @@ async def test_upsert_base_tools(server: SyncServer, default_user, event_loop): assert t.json_schema +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tool_type,expected_names", + [ + (ToolType.LETTA_CORE, BASE_TOOLS), + (ToolType.LETTA_MEMORY_CORE, BASE_MEMORY_TOOLS), + (ToolType.LETTA_MULTI_AGENT_CORE, MULTI_AGENT_TOOLS), + (ToolType.LETTA_SLEEPTIME_CORE, BASE_SLEEPTIME_TOOLS), + (ToolType.LETTA_VOICE_SLEEPTIME_CORE, sorted(set(BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS) - {"send_message"})), + (ToolType.LETTA_BUILTIN, BUILTIN_TOOLS), + (ToolType.LETTA_FILES_CORE, FILES_TOOLS), + ], +) +async def test_upsert_filtered_base_tools(server: SyncServer, default_user, tool_type, expected_names): + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={tool_type}) + tool_names = sorted([t.name for t in tools]) + expected_sorted = sorted(expected_names) + + assert tool_names == expected_sorted + assert all(t.tool_type == tool_type for t in tools) + + +@pytest.mark.asyncio +async def test_upsert_multiple_tool_types(server: SyncServer, default_user): + allowed = {ToolType.LETTA_CORE, 
ToolType.LETTA_BUILTIN, ToolType.LETTA_FILES_CORE} + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types=allowed) + tool_names = {t.name for t in tools} + expected = set(BASE_TOOLS + BUILTIN_TOOLS + FILES_TOOLS) + + assert tool_names == expected + assert all(t.tool_type in allowed for t in tools) + + +@pytest.mark.asyncio +async def test_upsert_base_tools_with_empty_type_filter(server: SyncServer, default_user): + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types=set()) + assert tools == [] + + # ====================================================================================================================== # Message Manager Tests # ====================================================================================================================== diff --git a/tests/test_sources.py b/tests/test_sources.py index ec8fd3cf..3cd5e671 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -9,6 +9,8 @@ from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient from letta_client.types import AgentState +from letta.constants import FILES_TOOLS +from letta.orm.enums import ToolType from letta.schemas.message import MessageCreate from tests.helpers.utils import retry_until_success from tests.utils import wait_for_server @@ -17,6 +19,17 @@ from tests.utils import wait_for_server SERVER_PORT = 8283 +@pytest.fixture(autouse=True) +def clear_sources_jobs(client: LettaSDKClient): + # Clear existing sources + for source in client.sources.list(): + client.sources.delete(source_id=source.id) + + # Clear existing jobs + for job in client.jobs.list(): + client.jobs.delete(job_id=job.id) + + def run_server(): load_dotenv() @@ -61,6 +74,61 @@ def agent_state(client: LettaSDKClient): yield agent_state +# Tests + + +def test_auto_attach_detach_files_tools(client: LettaSDKClient): + """Test automatic attachment and detachment of file tools when managing agent 
sources.""" + # Create agent with basic configuration + agent = client.agents.create( + memory_blocks=[ + CreateBlock(label="human", value="username: sarah"), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + ) + + # Helper function to get file tools from agent + def get_file_tools(agent_state): + return {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + + # Helper function to assert file tools presence + def assert_file_tools_present(agent_state, expected_tools): + actual_tools = get_file_tools(agent_state) + assert actual_tools == expected_tools, f"File tools mismatch.\nExpected: {expected_tools}\nFound: {actual_tools}" + + # Helper function to assert no file tools + def assert_no_file_tools(agent_state): + has_file_tools = any(tool.tool_type == ToolType.LETTA_FILES_CORE for tool in agent_state.tools) + assert not has_file_tools, "File tools should not be present" + + # Initial state: no file tools + assert_no_file_tools(agent) + + # Create and attach first source + source_1 = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + assert len(client.sources.list()) == 1 + + agent = client.agents.sources.attach(source_id=source_1.id, agent_id=agent.id) + assert_file_tools_present(agent, set(FILES_TOOLS)) + + # Create and attach second source + source_2 = client.sources.create(name="another_test_source", embedding="openai/text-embedding-ada-002") + assert len(client.sources.list()) == 2 + + agent = client.agents.sources.attach(source_id=source_2.id, agent_id=agent.id) + # File tools should remain after attaching second source + assert_file_tools_present(agent, set(FILES_TOOLS)) + + # Detach second source - tools should remain (first source still attached) + agent = client.agents.sources.detach(source_id=source_2.id, agent_id=agent.id) + assert_file_tools_present(agent, set(FILES_TOOLS)) + + # Detach first source - all file tools should be removed + agent = 
client.agents.sources.detach(source_id=source_1.id, agent_id=agent.id) + assert_no_file_tools(agent) + + @pytest.mark.parametrize( "file_path, expected_value, expected_label_regex", [ @@ -78,14 +146,6 @@ def test_file_upload_creates_source_blocks_correctly( expected_value: str, expected_label_regex: str, ): - # Clear existing sources - for source in client.sources.list(): - client.sources.delete(source_id=source.id) - - # Clear existing jobs - for job in client.jobs.list(): - client.jobs.delete(job_id=job.id) - # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") assert len(client.sources.list()) == 1 @@ -127,14 +187,6 @@ def test_file_upload_creates_source_blocks_correctly( def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): - # Clear existing sources - for source in client.sources.list(): - client.sources.delete(source_id=source.id) - - # Clear existing jobs - for job in client.jobs.list(): - client.jobs.delete(job_id=job.id) - # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") assert len(client.sources.list()) == 1 @@ -179,14 +231,6 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): - # Clear existing sources - for source in client.sources.list(): - client.sources.delete(source_id=source.id) - - # Clear existing jobs - for job in client.jobs.list(): - client.jobs.delete(job_id=job.id) - # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") assert len(client.sources.list()) == 1 @@ -227,22 +271,6 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a @retry_until_success(max_attempts=5, sleep_time_seconds=2) def 
test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState): - print(f"Starting test with agent ID: {agent_state.id}") - - # Clear existing sources - existing_sources = client.sources.list() - print(f"Found {len(existing_sources)} existing sources, clearing...") - for source in existing_sources: - print(f" Deleting source: {source.id}") - client.sources.delete(source_id=source.id) - - # Clear existing jobs - existing_jobs = client.jobs.list() - print(f"Found {len(existing_jobs)} existing jobs, clearing...") - for job in existing_jobs: - print(f" Deleting job: {job.id}") - client.jobs.delete(job_id=job.id) - # Create a new source print("Creating new source...") source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") @@ -372,22 +400,6 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat @retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: AgentState): - print(f"Starting test with agent ID: {agent_state.id}") - - # Clear existing sources - existing_sources = client.sources.list() - print(f"Found {len(existing_sources)} existing sources, clearing...") - for source in existing_sources: - print(f" Deleting source: {source.id}") - client.sources.delete(source_id=source.id) - - # Clear existing jobs - existing_jobs = client.jobs.list() - print(f"Found {len(existing_jobs)} existing jobs, clearing...") - for job in existing_jobs: - print(f" Deleting job: {job.id}") - client.jobs.delete(job_id=job.id) - # Create a new source print("Creating new source...") source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")