diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index bc05fe5f..fd3afcfd 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -115,9 +115,10 @@ class AgentManager: block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor) block_ids.append(block.id) - # TODO: Remove this block once we deprecate the legacy `tools` field - # create passed in `tools` - tool_names = [] + # add passed in `tools` + tool_names = agent_create.tools or [] + + # add base tools if agent_create.include_base_tools: if agent_create.agent_type == AgentType.sleeptime_agent: tool_names.extend(BASE_SLEEPTIME_TOOLS) @@ -128,42 +129,45 @@ class AgentManager: tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS) if agent_create.include_multi_agent_tools: tool_names.extend(MULTI_AGENT_TOOLS) - if agent_create.tools: - tool_names.extend(agent_create.tools) - # Remove duplicates + + # remove duplicates tool_names = list(set(tool_names)) - # add default tool rules - if agent_create.include_base_tool_rules: - if not agent_create.tool_rules: - tool_rules = [] - else: - tool_rules = agent_create.tool_rules + # convert tool names to ids + tool_ids = [] + for tool_name in tool_names: + tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + tool_ids.append(tool.id) + # add passed in `tool_ids` + for tool_id in agent_create.tool_ids or []: + if tool_id not in tool_ids: + tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) + if tool: + tool_ids.append(tool.id) + tool_names.append(tool.name) + else: + raise ValueError(f"Tool {tool_id} not found") + + # add default tool rules + tool_rules = agent_create.tool_rules or [] + if agent_create.include_base_tool_rules: # apply default tool rules for tool_name in tool_names: if tool_name == "send_message" or tool_name == "send_message_to_agent_async" or tool_name == "finish_rethinking_memory": tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name)) - elif tool_name in BASE_TOOLS: + elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS: tool_rules.append(PydanticContinueToolRule(tool_name=tool_name)) if agent_create.agent_type == AgentType.sleeptime_agent: tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"])) - else: - tool_rules = agent_create.tool_rules - # Check tool rules are valid + # if custom rules, check tool rules are valid if agent_create.tool_rules: check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules) - tool_ids = agent_create.tool_ids or [] - for tool_name in tool_names: - tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) - if tool: - tool_ids.append(tool.id) - # Remove duplicates - tool_ids = list(set(tool_ids)) - # Create the agent agent_state = self._create_agent( name=agent_create.name, diff --git a/tests/test_managers.py b/tests/test_managers.py index fb2f1c54..49373f5d 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -437,6 +437,7 @@ def sarah_agent(server: SyncServer, default_user, default_organization): memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -452,6 +453,7 @@ def charles_agent(server: SyncServer, default_user, default_organization): memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -476,6 +478,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")], tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"}, message_buffer_autoclear=True, + include_base_tools=False, ) created_agent = server.agent_manager.create_agent( create_agent_request, @@ -549,6 +552,7 @@ def agent_with_tags(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -560,6 +564,7 @@ def agent_with_tags(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -571,6 +576,7 @@ def agent_with_tags(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -672,6 +678,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use tags=["a", "b"], description="test_description", initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")], + include_base_tools=False, ) agent_state = server.agent_manager.create_agent( create_agent_request, @@ -697,6 +704,7 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user, block_ids=[default_block.id], tags=["a", "b"], description="test_description", + include_base_tools=False, ) agent_state = server.agent_manager.create_agent( create_agent_request, @@ -841,6 +849,7 @@ def test_list_agents_ascending(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -854,6 +863,7 @@ def test_list_agents_ascending(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -871,6 +881,7 @@ def test_list_agents_descending(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -884,6 +895,7 @@ def test_list_agents_descending(server: SyncServer, default_user): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -905,6 +917,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -1266,6 +1279,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -1281,6 +1295,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), memory_blocks=[], + include_base_tools=False, ), actor=default_user, ) @@ -1321,6 +1336,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def description="This is a search agent for testing", llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -1332,6 +1348,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def description="Another search agent for testing", llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -1343,6 +1360,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def description="This is a different agent", llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -3351,6 +3369,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_ llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), identity_ids=[identity.id], + include_base_tools=False, ), actor=default_user, ) @@ -3359,6 +3378,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_ memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, ), actor=default_user, ) @@ -4643,6 +4663,7 @@ def test_list_tags(server: SyncServer, default_user, default_organization): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), tags=tags[i : i + 3], # Each agent gets 3 consecutive tags + include_base_tools=False, ), ) agents.append(agent) diff --git a/tests/test_server.py b/tests/test_server.py index ec79fed5..51bd5816 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1116,7 +1116,47 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot assert any("Anna".lower() in passage.text.lower() for passage in passages2) -def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools): +def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools): + actor = server.user_manager.get_user_or_default(user_id) + + # create agent + with pytest.raises(ValueError, match="not found"): + agent_state = server.create_agent( + request=CreateAgent( + name="memory_rebuild_test_agent", + tools=["fake_nonexisting_tool"], + memory_blocks=[ + CreateBlock(label="human", value="The human's name is Bob."), + CreateBlock(label="persona", value="My name is Alice."), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + include_base_tools=True, + ), + actor=actor, + ) + + +def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_memory_tools): + actor = server.user_manager.get_user_or_default(user_id) + + # create agent + agent_state = server.create_agent( + request=CreateAgent( + name="tool_rules_test_agent", + tool_ids=[t.id for t in base_tools + base_memory_tools], + memory_blocks=[], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + include_base_tools=False, + ), + actor=actor, + ) + + assert len(agent_state.tool_rules) == len(base_tools + base_memory_tools) + + +def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools): """Test that the memory rebuild is generating the correct number of role=system messages""" actor = server.user_manager.get_user_or_default(user_id)