fix: update default include_base_tool_rules to None (#3762)
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -186,8 +186,8 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
include_multi_agent_tools: bool = Field(
|
||||
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
|
||||
)
|
||||
include_base_tool_rules: bool = Field(
|
||||
True, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)."
|
||||
include_base_tool_rules: Optional[bool] = Field(
|
||||
None, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)."
|
||||
)
|
||||
include_default_source: bool = Field(
|
||||
False, description="If true, automatically creates and attaches a default data source for this agent."
|
||||
|
||||
@@ -334,13 +334,20 @@ class AgentManager:
|
||||
|
||||
tool_rules = list(agent_create.tool_rules or [])
|
||||
|
||||
# Override include_base_tool_rules to True if provider is not in excluded set
|
||||
if agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES:
|
||||
# Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True
|
||||
if (
|
||||
(
|
||||
agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES
|
||||
and agent_create.include_base_tool_rules is None
|
||||
)
|
||||
and agent_create.agent_type != AgentType.sleeptime_agent
|
||||
) or agent_create.include_base_tool_rules is False:
|
||||
agent_create.include_base_tool_rules = False
|
||||
logger.info(f"Overriding include_base_tool_rules to True for provider: {agent_create.llm_config.model_endpoint_type}")
|
||||
logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}")
|
||||
else:
|
||||
agent_create.include_base_tool_rules = True
|
||||
|
||||
should_add_base_tool_rules = agent_create.include_base_tool_rules
|
||||
|
||||
if should_add_base_tool_rules:
|
||||
for tn in tool_names:
|
||||
if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}:
|
||||
@@ -534,16 +541,22 @@ class AgentManager:
|
||||
|
||||
tool_ids = set(name_to_id.values()) | set(id_to_name.keys())
|
||||
tool_names = set(name_to_id.keys()) # now canonical
|
||||
|
||||
tool_rules = list(agent_create.tool_rules or [])
|
||||
|
||||
# Override include_base_tool_rules to True if provider is not in excluded set
|
||||
if agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES:
|
||||
# Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True
|
||||
if (
|
||||
(
|
||||
agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES
|
||||
and agent_create.include_base_tool_rules is None
|
||||
)
|
||||
and agent_create.agent_type != AgentType.sleeptime_agent
|
||||
) or agent_create.include_base_tool_rules is False:
|
||||
agent_create.include_base_tool_rules = False
|
||||
logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}")
|
||||
else:
|
||||
agent_create.include_base_tool_rules = True
|
||||
|
||||
should_add_base_tool_rules = agent_create.include_base_tool_rules
|
||||
|
||||
if should_add_base_tool_rules:
|
||||
for tn in tool_names:
|
||||
if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}:
|
||||
|
||||
@@ -151,6 +151,8 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
||||
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
|
||||
|
||||
# Assert tool rules
|
||||
print("TOOLRULES", request.tool_rules)
|
||||
print("AGENTTOOLRULES", agent.tool_rules)
|
||||
if request.tool_rules:
|
||||
assert len(agent.tool_rules) == len(
|
||||
request.tool_rules
|
||||
|
||||
@@ -562,6 +562,7 @@ async def test_continue_tool_rule(server, default_user):
|
||||
include_base_tools=False,
|
||||
include_base_tool_rules=False,
|
||||
)
|
||||
print(agent_state)
|
||||
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
@@ -569,6 +570,7 @@ async def test_continue_tool_rule(server, default_user):
|
||||
input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")],
|
||||
actor=default_user,
|
||||
)
|
||||
print(response)
|
||||
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
assert_invoked_function_call(response.messages, "core_memory_append")
|
||||
|
||||
@@ -192,6 +192,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
enable_sleeptime=True,
|
||||
include_base_tool_rules=True,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -1288,6 +1288,8 @@ class TestAgentFileRoundTrip:
|
||||
result = await agent_serialization_manager.import_file(original_export, other_user)
|
||||
imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0")
|
||||
second_export = await agent_serialization_manager.export([imported_agent_id], other_user)
|
||||
print(original_export.agents[0].tool_rules)
|
||||
print(second_export.agents[0].tool_rules)
|
||||
assert compare_agent_files(original_export, second_export)
|
||||
|
||||
async def test_multiple_roundtrips(self, server, agent_serialization_manager, test_agent, default_user, other_user):
|
||||
|
||||
@@ -817,7 +817,7 @@ async def test_create_agent_base_tool_rules_excluded_providers(server: SyncServe
|
||||
memory_blocks=memory_blocks,
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"), # This has model_endpoint_type="openai"
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tool_rules=True, # Should be overridden to False
|
||||
include_base_tool_rules=False,
|
||||
)
|
||||
|
||||
# Create the agent
|
||||
@@ -827,6 +827,7 @@ async def test_create_agent_base_tool_rules_excluded_providers(server: SyncServe
|
||||
)
|
||||
|
||||
# Assert that no base tool rules were added (since include_base_tool_rules was overridden to False)
|
||||
print(created_agent.tool_rules)
|
||||
assert created_agent.tool_rules is None or len(created_agent.tool_rules) == 0
|
||||
|
||||
|
||||
|
||||
@@ -1022,7 +1022,6 @@ def test_preview_payload(client: LettaSDKClient):
|
||||
assert system_message["role"] == "system"
|
||||
assert "base_instructions" in system_message["content"]
|
||||
assert "memory_blocks" in system_message["content"]
|
||||
assert "tool_usage_rules" in system_message["content"]
|
||||
assert "Letta" in system_message["content"]
|
||||
|
||||
assert isinstance(payload["tools"], list)
|
||||
|
||||
@@ -983,6 +983,7 @@ def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_m
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
include_base_tools=False,
|
||||
include_base_tool_rules=True,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user