feat: more robust tools setup in agent creation (#1605)

This commit is contained in:
cthomas
2025-04-07 20:15:16 -07:00
committed by GitHub
parent 148727a44a
commit aef866ef3d
3 changed files with 90 additions and 25 deletions

View File

@@ -115,9 +115,10 @@ class AgentManager:
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor)
block_ids.append(block.id)
# TODO: Remove this block once we deprecate the legacy `tools` field
# create passed in `tools`
tool_names = []
# add passed in `tools`
tool_names = agent_create.tools or []
# add base tools
if agent_create.include_base_tools:
if agent_create.agent_type == AgentType.sleeptime_agent:
tool_names.extend(BASE_SLEEPTIME_TOOLS)
@@ -128,42 +129,45 @@ class AgentManager:
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
if agent_create.include_multi_agent_tools:
tool_names.extend(MULTI_AGENT_TOOLS)
if agent_create.tools:
tool_names.extend(agent_create.tools)
# Remove duplicates
# remove duplicates
tool_names = list(set(tool_names))
# add default tool rules
if agent_create.include_base_tool_rules:
if not agent_create.tool_rules:
tool_rules = []
else:
tool_rules = agent_create.tool_rules
# convert tool names to ids
tool_ids = []
for tool_name in tool_names:
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
tool_ids.append(tool.id)
# add passed in `tool_ids`
for tool_id in agent_create.tool_ids or []:
if tool_id not in tool_ids:
tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
if tool:
tool_ids.append(tool.id)
tool_names.append(tool.name)
else:
raise ValueError(f"Tool {tool_id} not found")
# add default tool rules
tool_rules = agent_create.tool_rules or []
if agent_create.include_base_tool_rules:
# apply default tool rules
for tool_name in tool_names:
if tool_name == "send_message" or tool_name == "send_message_to_agent_async" or tool_name == "finish_rethinking_memory":
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
elif tool_name in BASE_TOOLS:
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
if agent_create.agent_type == AgentType.sleeptime_agent:
tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
else:
tool_rules = agent_create.tool_rules
# Check tool rules are valid
# if custom rules, check tool rules are valid
if agent_create.tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
if tool:
tool_ids.append(tool.id)
# Remove duplicates
tool_ids = list(set(tool_ids))
# Create the agent
agent_state = self._create_agent(
name=agent_create.name,

View File

@@ -437,6 +437,7 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -452,6 +453,7 @@ def charles_agent(server: SyncServer, default_user, default_organization):
memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -476,6 +478,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
message_buffer_autoclear=True,
include_base_tools=False,
)
created_agent = server.agent_manager.create_agent(
create_agent_request,
@@ -549,6 +552,7 @@ def agent_with_tags(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -560,6 +564,7 @@ def agent_with_tags(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -571,6 +576,7 @@ def agent_with_tags(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -672,6 +678,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
tags=["a", "b"],
description="test_description",
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
include_base_tools=False,
)
agent_state = server.agent_manager.create_agent(
create_agent_request,
@@ -697,6 +704,7 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user,
block_ids=[default_block.id],
tags=["a", "b"],
description="test_description",
include_base_tools=False,
)
agent_state = server.agent_manager.create_agent(
create_agent_request,
@@ -841,6 +849,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -854,6 +863,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -871,6 +881,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -884,6 +895,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -905,6 +917,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -1266,6 +1279,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -1281,6 +1295,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
include_base_tools=False,
),
actor=default_user,
)
@@ -1321,6 +1336,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
description="This is a search agent for testing",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -1332,6 +1348,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
description="Another search agent for testing",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -1343,6 +1360,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
description="This is a different agent",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -3351,6 +3369,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
identity_ids=[identity.id],
include_base_tools=False,
),
actor=default_user,
)
@@ -3359,6 +3378,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
include_base_tools=False,
),
actor=default_user,
)
@@ -4643,6 +4663,7 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
tags=tags[i : i + 3], # Each agent gets 3 consecutive tags
include_base_tools=False,
),
)
agents.append(agent)

View File

@@ -1116,7 +1116,47 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools):
def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools):
actor = server.user_manager.get_user_or_default(user_id)
# create agent
with pytest.raises(ValueError, match="not found"):
agent_state = server.create_agent(
request=CreateAgent(
name="memory_rebuild_test_agent",
tools=["fake_nonexisting_tool"],
memory_blocks=[
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
include_base_tools=True,
),
actor=actor,
)
def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_memory_tools):
actor = server.user_manager.get_user_or_default(user_id)
# create agent
agent_state = server.create_agent(
request=CreateAgent(
name="tool_rules_test_agent",
tool_ids=[t.id for t in base_tools + base_memory_tools],
memory_blocks=[],
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
include_base_tools=False,
),
actor=actor,
)
assert len(agent_state.tool_rules) == len(base_tools + base_memory_tools)
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)