fix: update default include_base_tool_rules to None (#3762)

Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
This commit is contained in:
Kevin Lin
2025-08-06 15:58:29 -07:00
committed by GitHub
parent 76679e3ecc
commit e20f4eca92
9 changed files with 33 additions and 12 deletions

View File

@@ -186,8 +186,8 @@ class CreateAgent(BaseModel, validate_assignment=True): #
include_multi_agent_tools: bool = Field(
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
)
include_base_tool_rules: bool = Field(
True, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)."
include_base_tool_rules: Optional[bool] = Field(
None, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)."
)
include_default_source: bool = Field(
False, description="If true, automatically creates and attaches a default data source for this agent."

View File

@@ -334,13 +334,20 @@ class AgentManager:
tool_rules = list(agent_create.tool_rules or [])
# Override include_base_tool_rules to True if provider is not in excluded set
if agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES:
# Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True
if (
(
agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES
and agent_create.include_base_tool_rules is None
)
and agent_create.agent_type != AgentType.sleeptime_agent
) or agent_create.include_base_tool_rules is False:
agent_create.include_base_tool_rules = False
logger.info(f"Overriding include_base_tool_rules to True for provider: {agent_create.llm_config.model_endpoint_type}")
logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}")
else:
agent_create.include_base_tool_rules = True
should_add_base_tool_rules = agent_create.include_base_tool_rules
if should_add_base_tool_rules:
for tn in tool_names:
if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}:
@@ -534,16 +541,22 @@ class AgentManager:
tool_ids = set(name_to_id.values()) | set(id_to_name.keys())
tool_names = set(name_to_id.keys()) # now canonical
tool_rules = list(agent_create.tool_rules or [])
# Override include_base_tool_rules to True if provider is not in excluded set
if agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES:
# Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True
if (
(
agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES
and agent_create.include_base_tool_rules is None
)
and agent_create.agent_type != AgentType.sleeptime_agent
) or agent_create.include_base_tool_rules is False:
agent_create.include_base_tool_rules = False
logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}")
else:
agent_create.include_base_tool_rules = True
should_add_base_tool_rules = agent_create.include_base_tool_rules
if should_add_base_tool_rules:
for tn in tool_names:
if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}:

View File

@@ -151,6 +151,8 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
# Assert tool rules
print("TOOLRULES", request.tool_rules)
print("AGENTTOOLRULES", agent.tool_rules)
if request.tool_rules:
assert len(agent.tool_rules) == len(
request.tool_rules

View File

@@ -562,6 +562,7 @@ async def test_continue_tool_rule(server, default_user):
include_base_tools=False,
include_base_tool_rules=False,
)
print(agent_state)
response = await run_agent_step(
server=server,
@@ -569,6 +570,7 @@ async def test_continue_tool_rule(server, default_user):
input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")],
actor=default_user,
)
print(response)
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "core_memory_append")

View File

@@ -192,6 +192,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
model="anthropic/claude-3-5-sonnet-20240620",
embedding="openai/text-embedding-3-small",
enable_sleeptime=True,
include_base_tool_rules=True,
),
actor=actor,
)

View File

@@ -1288,6 +1288,8 @@ class TestAgentFileRoundTrip:
result = await agent_serialization_manager.import_file(original_export, other_user)
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
second_export = await agent_serialization_manager.export([imported_agent_id], other_user)
print(original_export.agents[0].tool_rules)
print(second_export.agents[0].tool_rules)
assert compare_agent_files(original_export, second_export)
async def test_multiple_roundtrips(self, server, agent_serialization_manager, test_agent, default_user, other_user):

View File

@@ -817,7 +817,7 @@ async def test_create_agent_base_tool_rules_excluded_providers(server: SyncServe
memory_blocks=memory_blocks,
llm_config=LLMConfig.default_config("gpt-4o-mini"), # This has model_endpoint_type="openai"
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tool_rules=True, # Should be overridden to False
include_base_tool_rules=False,
)
# Create the agent
@@ -827,6 +827,7 @@ async def test_create_agent_base_tool_rules_excluded_providers(server: SyncServe
)
# Assert that no base tool rules were added (since include_base_tool_rules was overridden to False)
print(created_agent.tool_rules)
assert created_agent.tool_rules is None or len(created_agent.tool_rules) == 0

View File

@@ -1022,7 +1022,6 @@ def test_preview_payload(client: LettaSDKClient):
assert system_message["role"] == "system"
assert "base_instructions" in system_message["content"]
assert "memory_blocks" in system_message["content"]
assert "tool_usage_rules" in system_message["content"]
assert "Letta" in system_message["content"]
assert isinstance(payload["tools"], list)

View File

@@ -983,6 +983,7 @@ def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_m
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-3-small",
include_base_tools=False,
include_base_tool_rules=True,
),
actor=actor,
)