feat: more robust tools setup in agent creation (#1605)
This commit is contained in:
@@ -115,9 +115,10 @@ class AgentManager:
|
||||
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor)
|
||||
block_ids.append(block.id)
|
||||
|
||||
# TODO: Remove this block once we deprecate the legacy `tools` field
|
||||
# create passed in `tools`
|
||||
tool_names = []
|
||||
# add passed in `tools`
|
||||
tool_names = agent_create.tools or []
|
||||
|
||||
# add base tools
|
||||
if agent_create.include_base_tools:
|
||||
if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||
tool_names.extend(BASE_SLEEPTIME_TOOLS)
|
||||
@@ -128,42 +129,45 @@ class AgentManager:
|
||||
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
if agent_create.include_multi_agent_tools:
|
||||
tool_names.extend(MULTI_AGENT_TOOLS)
|
||||
if agent_create.tools:
|
||||
tool_names.extend(agent_create.tools)
|
||||
# Remove duplicates
|
||||
|
||||
# remove duplicates
|
||||
tool_names = list(set(tool_names))
|
||||
|
||||
# add default tool rules
|
||||
if agent_create.include_base_tool_rules:
|
||||
if not agent_create.tool_rules:
|
||||
tool_rules = []
|
||||
else:
|
||||
tool_rules = agent_create.tool_rules
|
||||
# convert tool names to ids
|
||||
tool_ids = []
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
tool_ids.append(tool.id)
|
||||
|
||||
# add passed in `tool_ids`
|
||||
for tool_id in agent_create.tool_ids or []:
|
||||
if tool_id not in tool_ids:
|
||||
tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
if tool:
|
||||
tool_ids.append(tool.id)
|
||||
tool_names.append(tool.name)
|
||||
else:
|
||||
raise ValueError(f"Tool {tool_id} not found")
|
||||
|
||||
# add default tool rules
|
||||
tool_rules = agent_create.tool_rules or []
|
||||
if agent_create.include_base_tool_rules:
|
||||
# apply default tool rules
|
||||
for tool_name in tool_names:
|
||||
if tool_name == "send_message" or tool_name == "send_message_to_agent_async" or tool_name == "finish_rethinking_memory":
|
||||
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
|
||||
elif tool_name in BASE_TOOLS:
|
||||
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
|
||||
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
|
||||
|
||||
if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||
tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
|
||||
|
||||
else:
|
||||
tool_rules = agent_create.tool_rules
|
||||
# Check tool rules are valid
|
||||
# if custom rules, check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
|
||||
|
||||
tool_ids = agent_create.tool_ids or []
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
if tool:
|
||||
tool_ids.append(tool.id)
|
||||
# Remove duplicates
|
||||
tool_ids = list(set(tool_ids))
|
||||
|
||||
# Create the agent
|
||||
agent_state = self._create_agent(
|
||||
name=agent_create.name,
|
||||
|
||||
@@ -437,6 +437,7 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -452,6 +453,7 @@ def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -476,6 +478,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||
message_buffer_autoclear=True,
|
||||
include_base_tools=False,
|
||||
)
|
||||
created_agent = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -549,6 +552,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -560,6 +564,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -571,6 +576,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -672,6 +678,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -697,6 +704,7 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user,
|
||||
block_ids=[default_block.id],
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -841,6 +849,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -854,6 +863,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -871,6 +881,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -884,6 +895,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -905,6 +917,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1266,6 +1279,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1281,6 +1295,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1321,6 +1336,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
description="This is a search agent for testing",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1332,6 +1348,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
description="Another search agent for testing",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1343,6 +1360,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
description="This is a different agent",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -3351,6 +3369,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
identity_ids=[identity.id],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -3359,6 +3378,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -4643,6 +4663,7 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=tags[i : i + 3], # Each agent gets 3 consecutive tags
|
||||
include_base_tools=False,
|
||||
),
|
||||
)
|
||||
agents.append(agent)
|
||||
|
||||
@@ -1116,7 +1116,47 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
|
||||
|
||||
|
||||
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools):
|
||||
def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# create agent
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="memory_rebuild_test_agent",
|
||||
tools=["fake_nonexisting_tool"],
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="The human's name is Bob."),
|
||||
CreateBlock(label="persona", value="My name is Alice."),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
include_base_tools=True,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_memory_tools):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="tool_rules_test_agent",
|
||||
tool_ids=[t.id for t in base_tools + base_memory_tools],
|
||||
memory_blocks=[],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
assert len(agent_state.tool_rules) == len(base_tools + base_memory_tools)
|
||||
|
||||
|
||||
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user