From bba9753e1c46ae1283c5ba75a59bc86b8752c04c Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Wed, 6 Aug 2025 15:58:29 -0700 Subject: [PATCH] fix: update default `include_base_tool_rules` to None (#3762) Co-authored-by: Matthew Zhou --- letta/schemas/agent.py | 4 +-- letta/services/agent_manager.py | 29 ++++++++++++++++------ tests/helpers/utils.py | 2 ++ tests/integration_test_agent_tool_graph.py | 2 ++ tests/integration_test_sleeptime_agent.py | 1 + tests/test_agent_serialization_v2.py | 2 ++ tests/test_managers.py | 3 ++- tests/test_sdk_client.py | 1 - tests/test_server.py | 1 + 9 files changed, 33 insertions(+), 12 deletions(-) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index b7cee629..0b654a1d 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -186,8 +186,8 @@ class CreateAgent(BaseModel, validate_assignment=True): # include_multi_agent_tools: bool = Field( False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)." ) - include_base_tool_rules: bool = Field( - True, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)." + include_base_tool_rules: Optional[bool] = Field( + None, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)." ) include_default_source: bool = Field( False, description="If true, automatically creates and attaches a default data source for this agent." diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index c3e57c5c..9ad5e41a 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -334,13 +334,20 @@ class AgentManager: tool_rules = list(agent_create.tool_rules or []) - # Override include_base_tool_rules to True if provider is not in excluded set - if agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES: + # Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True + if ( + ( + agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES + and agent_create.include_base_tool_rules is None + ) + and agent_create.agent_type != AgentType.sleeptime_agent + ) or agent_create.include_base_tool_rules is False: agent_create.include_base_tool_rules = False - logger.info(f"Overriding include_base_tool_rules to True for provider: {agent_create.llm_config.model_endpoint_type}") + logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}") + else: + agent_create.include_base_tool_rules = True should_add_base_tool_rules = agent_create.include_base_tool_rules - if should_add_base_tool_rules: for tn in tool_names: if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}: @@ -534,16 +541,22 @@ class AgentManager: tool_ids = set(name_to_id.values()) | set(id_to_name.keys()) tool_names = set(name_to_id.keys()) # now canonical - tool_rules = list(agent_create.tool_rules or []) - # Override include_base_tool_rules to True if provider is not in excluded set - if agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES: + # Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True + if ( + ( + agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES + and agent_create.include_base_tool_rules is None + ) + and agent_create.agent_type != AgentType.sleeptime_agent + ) or agent_create.include_base_tool_rules is False: agent_create.include_base_tool_rules = False logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}") + else: + agent_create.include_base_tool_rules = True should_add_base_tool_rules = agent_create.include_base_tool_rules - if should_add_base_tool_rules: for tn in tool_names: if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}: diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 3d74e430..36918df7 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -151,6 +151,8 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}" # Assert tool rules + print("TOOLRULES", request.tool_rules) + print("AGENTTOOLRULES", agent.tool_rules) if request.tool_rules: assert len(agent.tool_rules) == len( request.tool_rules diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 57445d68..a6964ed3 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -562,6 +562,7 @@ async def test_continue_tool_rule(server, default_user): include_base_tools=False, include_base_tool_rules=False, ) + print(agent_state) response = await run_agent_step( server=server, @@ -569,6 +570,7 @@ async def test_continue_tool_rule(server, default_user): input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")], actor=default_user, ) + print(response) assert_invoked_function_call(response.messages, "send_message") assert_invoked_function_call(response.messages, "core_memory_append") diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index cb2a66cd..e2190633 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -192,6 +192,7 @@ async def test_sleeptime_group_chat_v2(server, actor): model="anthropic/claude-3-5-sonnet-20240620", embedding="openai/text-embedding-3-small", enable_sleeptime=True, + include_base_tool_rules=True, ), actor=actor, ) diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 48dc7229..d17e490f 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -1288,6 +1288,8 @@ class TestAgentFileRoundTrip: result = await agent_serialization_manager.import_file(original_export, other_user) imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") second_export = await agent_serialization_manager.export([imported_agent_id], other_user) + print(original_export.agents[0].tool_rules) + print(second_export.agents[0].tool_rules) assert compare_agent_files(original_export, second_export) async def test_multiple_roundtrips(self, server, agent_serialization_manager, test_agent, default_user, other_user): diff --git a/tests/test_managers.py b/tests/test_managers.py index 5f6b0d1c..cc6ba588 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -817,7 +817,7 @@ async def test_create_agent_base_tool_rules_excluded_providers(server: SyncServe memory_blocks=memory_blocks, llm_config=LLMConfig.default_config("gpt-4o-mini"), # This has model_endpoint_type="openai" embedding_config=EmbeddingConfig.default_config(provider="openai"), - include_base_tool_rules=True, # Should be overridden to False + include_base_tool_rules=False, ) # Create the agent @@ -827,6 +827,7 @@ async def test_create_agent_base_tool_rules_excluded_providers(server: SyncServe ) # Assert that no base tool rules were added (since include_base_tool_rules was overridden to False) + print(created_agent.tool_rules) assert created_agent.tool_rules is None or len(created_agent.tool_rules) == 0 diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 335141bf..fddfafa7 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1022,7 +1022,6 @@ def test_preview_payload(client: LettaSDKClient): assert system_message["role"] == "system" assert "base_instructions" in system_message["content"] assert "memory_blocks" in system_message["content"] - assert "tool_usage_rules" in system_message["content"] assert "Letta" in system_message["content"] assert isinstance(payload["tools"], list) diff --git a/tests/test_server.py b/tests/test_server.py index dd0727ac..bc093c12 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -983,6 +983,7 @@ def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_m model="openai/gpt-4o-mini", embedding="openai/text-embedding-3-small", include_base_tools=False, + include_base_tool_rules=True, ), actor=actor, )