feat: De-dupe tool rules [LET-4091] (#4282)
* Add hash/eqs for de-dupe * Add sdk test
This commit is contained in:
@@ -9,9 +9,9 @@ from typing import List, Type
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import ContinueToolRule, CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest, MessageCreate, TerminalToolRule, TextContent
|
||||
from letta_client import LettaRequest, MaxCountPerStepToolRule, MessageCreate, TerminalToolRule, TextContent
|
||||
from letta_client.client import BaseTool
|
||||
from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
@@ -1251,6 +1251,59 @@ def test_agent_tools_list(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
def test_agent_tool_rules_deduplication(client: LettaSDKClient):
|
||||
"""Test that duplicate tool rules are properly deduplicated when creating/updating agents."""
|
||||
# Create agent with duplicate tool rules
|
||||
duplicate_rules = [
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
TerminalToolRule(tool_name="send_message"), # exact duplicate
|
||||
TerminalToolRule(tool_name="send_message"), # another duplicate
|
||||
]
|
||||
|
||||
agent_state = client.agents.create(
|
||||
name="test_agent_dedup",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a helpful assistant.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_rules=duplicate_rules,
|
||||
include_base_tools=False,
|
||||
)
|
||||
|
||||
# Get the agent and check tool rules
|
||||
retrieved_agent = client.agents.retrieve(agent_id=agent_state.id)
|
||||
assert len(retrieved_agent.tool_rules) == 1, f"Expected 1 unique tool rule, got {len(retrieved_agent.tool_rules)}"
|
||||
assert retrieved_agent.tool_rules[0].tool_name == "send_message"
|
||||
assert retrieved_agent.tool_rules[0].type == "exit_loop"
|
||||
|
||||
# Test update with duplicates
|
||||
update_rules = [
|
||||
ContinueToolRule(tool_name="conversation_search"),
|
||||
ContinueToolRule(tool_name="conversation_search"), # duplicate
|
||||
MaxCountPerStepToolRule(tool_name="test_tool", max_count_limit=2),
|
||||
MaxCountPerStepToolRule(tool_name="test_tool", max_count_limit=2), # exact duplicate
|
||||
MaxCountPerStepToolRule(tool_name="test_tool", max_count_limit=3), # different limit, not a duplicate
|
||||
]
|
||||
|
||||
updated_agent = client.agents.modify(agent_id=agent_state.id, tool_rules=update_rules)
|
||||
|
||||
# Check that duplicates were removed
|
||||
assert len(updated_agent.tool_rules) == 3, f"Expected 3 unique tool rules after update, got {len(updated_agent.tool_rules)}"
|
||||
|
||||
# Verify the specific rules
|
||||
rule_set = {(r.tool_name, r.type, getattr(r, "max_count_limit", None)) for r in updated_agent.tool_rules}
|
||||
expected_set = {
|
||||
("conversation_search", "continue_loop", None),
|
||||
("test_tool", "max_count_per_step", 2),
|
||||
("test_tool", "max_count_per_step", 3),
|
||||
}
|
||||
assert rule_set == expected_set, f"Tool rules don't match expected. Got: {rule_set}"
|
||||
|
||||
|
||||
def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient):
|
||||
"""Test adding a tool with multiple functions in the source code"""
|
||||
import textwrap
|
||||
|
||||
Reference in New Issue
Block a user