feat: Enable dynamic toggling of tool choice in v3 agent loop for OpenAI [LET-4564] (#5042)
* Add subsequent flag * Finish integrating constrained/unconstrained toggling on v3 agent loop * Update tests to run on v3 * Run lint
This commit is contained in:
committed by
Caren Thomas
parent
c465da27e6
commit
df5c997da0
@@ -51,6 +51,10 @@ class LettaAgentV3(LettaAgentV2):
|
||||
* Support Gemini / OpenAI client
|
||||
"""
|
||||
|
||||
def _initialize_state(self):
|
||||
super()._initialize_state()
|
||||
self._require_tool_call = False
|
||||
|
||||
@trace_method
|
||||
async def step(
|
||||
self,
|
||||
@@ -279,6 +283,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
try:
|
||||
self.last_function_response = _load_last_function_response(messages)
|
||||
valid_tools = await self._get_valid_tools()
|
||||
require_tool_call = self.tool_rules_solver.should_force_tool_call()
|
||||
|
||||
if self._require_tool_call != require_tool_call:
|
||||
if require_tool_call:
|
||||
self.logger.info("switching to constrained mode (forcing tool call)")
|
||||
else:
|
||||
self.logger.info("switching to unconstrained mode (allowing non-tool responses)")
|
||||
self._require_tool_call = require_tool_call
|
||||
|
||||
approval_request, approval_response = _maybe_get_approval_messages(messages)
|
||||
if approval_request and approval_response:
|
||||
tool_call = approval_request.tool_calls[0]
|
||||
@@ -307,6 +320,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
llm_config=self.agent_state.llm_config,
|
||||
tools=valid_tools,
|
||||
force_tool_call=force_tool_call,
|
||||
requires_subsequent_tool_call=self._require_tool_call,
|
||||
)
|
||||
if dry_run:
|
||||
yield request_data
|
||||
@@ -590,8 +604,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
elif tool_call is None:
|
||||
# TODO could just hardcode the line here instead of calling the function...
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
# agent_state=agent_state,
|
||||
# request_heartbeat=False,
|
||||
agent_state=agent_state,
|
||||
tool_call_name=None,
|
||||
tool_rule_violated=False,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
@@ -705,8 +718,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# 4. Decide whether to keep stepping (focal section simplified)
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
# agent_state=agent_state,
|
||||
# request_heartbeat=request_heartbeat,
|
||||
agent_state=agent_state,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_rule_violated=tool_rule_violated,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
@@ -753,8 +765,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
@trace_method
|
||||
def _decide_continuation(
|
||||
self,
|
||||
# agent_state: AgentState,
|
||||
# request_heartbeat: bool,
|
||||
agent_state: AgentState,
|
||||
tool_call_name: Optional[str],
|
||||
tool_rule_violated: bool,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
@@ -771,19 +782,14 @@ class LettaAgentV3(LettaAgentV2):
|
||||
2c. Called tool + tool rule violation (did not execute)
|
||||
|
||||
"""
|
||||
continue_stepping = True # Default continue
|
||||
continuation_reason: str | None = None
|
||||
stop_reason: LettaStopReason | None = None
|
||||
|
||||
if tool_call_name is None:
|
||||
# No tool call? End loop
|
||||
return False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
|
||||
else:
|
||||
# If we have a tool call, we continue stepping
|
||||
return True, None, None
|
||||
|
||||
# TODO support tool rules
|
||||
# I think we can just uncomment the bellow?
|
||||
if tool_rule_violated:
|
||||
continue_stepping = True
|
||||
continuation_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
|
||||
@@ -791,8 +797,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
tool_rules_solver.register_tool_call(tool_call_name)
|
||||
|
||||
if tool_rules_solver.is_terminal_tool(tool_call_name):
|
||||
if continue_stepping:
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
|
||||
continue_stepping = False
|
||||
|
||||
elif tool_rules_solver.has_children_tools(tool_call_name):
|
||||
@@ -809,7 +814,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||
else:
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||
if not continue_stepping and uncalled:
|
||||
if uncalled:
|
||||
continue_stepping = True
|
||||
continuation_reason = (
|
||||
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
|
||||
@@ -817,7 +822,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
stop_reason = None # reset – we’re still going
|
||||
|
||||
return continue_stepping, continuation_reason, stop_reason
|
||||
return continue_stepping, continuation_reason, stop_reason
|
||||
|
||||
@trace_method
|
||||
async def _get_valid_tools(self):
|
||||
|
||||
@@ -183,6 +183,7 @@ class AnthropicClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
|
||||
|
||||
@@ -70,8 +70,9 @@ class BedrockClient(AnthropicClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
# remove disallowed fields
|
||||
if "tool_choice" in data:
|
||||
del data["tool_choice"]["disable_parallel_tool_use"]
|
||||
|
||||
@@ -337,11 +337,12 @@ class DeepseekClient(OpenAIClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
|
||||
llm_config.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
def add_functions_to_system_message(system_message: ChatMessage):
|
||||
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"
|
||||
|
||||
@@ -280,6 +280,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
@@ -29,8 +29,9 @@ class GroqClient(OpenAIClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Groq validation - these fields are not supported and will cause 400 errors
|
||||
# https://console.groq.com/docs/openai
|
||||
|
||||
@@ -127,6 +127,7 @@ class LLMClientBase:
|
||||
llm_config: LLMConfig,
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
@@ -206,6 +206,7 @@ class OpenAIClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI Responses API.
|
||||
@@ -224,14 +225,15 @@ class OpenAIClient(LLMClientBase):
|
||||
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
|
||||
model = None
|
||||
|
||||
# Default to auto, unless there's a forced tool call coming from above
|
||||
# Default to auto, unless there's a forced tool call coming from above or requires_subsequent_tool_call is True
|
||||
tool_choice = None
|
||||
if tools: # only set tool_choice if tools exist
|
||||
tool_choice = (
|
||||
"auto"
|
||||
if force_tool_call is None
|
||||
else ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
)
|
||||
if force_tool_call is not None:
|
||||
tool_choice = {"type": "function", "name": force_tool_call}
|
||||
elif requires_subsequent_tool_call:
|
||||
tool_choice = "required"
|
||||
else:
|
||||
tool_choice = "auto"
|
||||
|
||||
# Convert the tools from the ChatCompletions style to the Responses style
|
||||
if tools:
|
||||
@@ -352,6 +354,7 @@ class OpenAIClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI API.
|
||||
@@ -364,6 +367,7 @@ class OpenAIClient(LLMClientBase):
|
||||
llm_config=llm_config,
|
||||
tools=tools,
|
||||
force_tool_call=force_tool_call,
|
||||
requires_subsequent_tool_call=requires_subsequent_tool_call,
|
||||
)
|
||||
|
||||
if agent_type == AgentType.letta_v1_agent:
|
||||
@@ -407,15 +411,16 @@ class OpenAIClient(LLMClientBase):
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if tools: # only set tool_choice if tools exist
|
||||
if self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
||||
if force_tool_call is not None:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
elif requires_subsequent_tool_call:
|
||||
tool_choice = "required"
|
||||
elif self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
||||
tool_choice = "auto"
|
||||
else:
|
||||
# only set if tools is non-Null
|
||||
tool_choice = "required"
|
||||
|
||||
if force_tool_call is not None:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
|
||||
data = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=fill_image_content_in_messages(openai_message_list, messages),
|
||||
|
||||
@@ -29,8 +29,9 @@ class XAIClient(OpenAIClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Specific bug for the mini models (as of Apr 14, 2025)
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
||||
|
||||
7
tests/configs/llm_model_configs/openai-gpt-4.1.json
Normal file
7
tests/configs/llm_model_configs/openai-gpt-4.1.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"context_window": 1047576,
|
||||
"model": "gpt-4.1-2025-04-14",
|
||||
"model_endpoint_type": "openai",
|
||||
"model_endpoint": "https://api.openai.com/v1",
|
||||
"model_wrapper": null
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
import pytest
|
||||
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.message import MessageCreate
|
||||
@@ -12,7 +13,6 @@ from letta.schemas.run import Run
|
||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
@@ -25,7 +25,9 @@ from tests.utils import create_tool_from_func
|
||||
# Generate uuid for agent name for this example
|
||||
namespace = uuid.NAMESPACE_DNS
|
||||
agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
OPENAI_CONFIG = "tests/configs/llm_model_configs/openai-gpt-4.1.json"
|
||||
CLAUDE_SONNET_CONFIG = "tests/configs/llm_model_configs/claude-4-sonnet.json"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -38,6 +40,12 @@ async def server():
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def default_config_file():
|
||||
"""Provides the default config file path for tests."""
|
||||
return OPENAI_CONFIG
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def first_secret_tool(server):
|
||||
def first_secret_word():
|
||||
@@ -241,7 +249,7 @@ async def default_user(server):
|
||||
|
||||
async def run_agent_step(agent_state, input_messages, actor):
|
||||
"""Helper function to run agent step using LettaAgent directly instead of server.send_messages."""
|
||||
agent_loop = LettaAgentV2(
|
||||
agent_loop = LettaAgentV3(
|
||||
agent_state=agent_state,
|
||||
actor=actor,
|
||||
)
|
||||
@@ -283,7 +291,7 @@ async def test_single_path_agent_tool_call_graph(
|
||||
]
|
||||
|
||||
# Make agent state
|
||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = await run_agent_step(
|
||||
agent_state=agent_state,
|
||||
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
|
||||
@@ -321,9 +329,8 @@ async def test_single_path_agent_tool_call_graph(
|
||||
@pytest.mark.parametrize(
|
||||
"config_file",
|
||||
[
|
||||
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
CLAUDE_SONNET_CONFIG,
|
||||
OPENAI_CONFIG,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("init_tools_case", ["single", "multiple"])
|
||||
@@ -383,13 +390,12 @@ async def test_claude_initial_tool_rule_enforced(
|
||||
TerminalToolRule(tool_name=second_secret_tool.name),
|
||||
]
|
||||
tools = [first_secret_tool, second_secret_tool]
|
||||
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
|
||||
for i in range(3):
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
agent_state = await setup_agent(
|
||||
server,
|
||||
anthropic_config_file,
|
||||
CLAUDE_SONNET_CONFIG,
|
||||
agent_uuid=agent_uuid,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules,
|
||||
@@ -426,8 +432,7 @@ async def test_claude_initial_tool_rule_enforced(
|
||||
@pytest.mark.parametrize(
|
||||
"config_file",
|
||||
[
|
||||
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
OPENAI_CONFIG,
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -508,13 +513,12 @@ async def test_init_tool_rule_always_fails(
|
||||
include_base_tools,
|
||||
):
|
||||
"""Test behavior when InitToolRule invokes a tool that always fails."""
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
|
||||
tool_rule = InitToolRule(tool_name=auto_error_tool.name)
|
||||
agent_state = await setup_agent(
|
||||
server,
|
||||
config_file,
|
||||
OPENAI_CONFIG,
|
||||
agent_uuid=agent_uuid,
|
||||
tool_ids=[auto_error_tool.id],
|
||||
tool_rules=[tool_rule],
|
||||
@@ -535,7 +539,6 @@ async def test_init_tool_rule_always_fails(
|
||||
@pytest.mark.asyncio
|
||||
async def test_continue_tool_rule(server, default_user):
|
||||
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
|
||||
tools = [
|
||||
@@ -551,7 +554,7 @@ async def test_continue_tool_rule(server, default_user):
|
||||
|
||||
agent_state = await setup_agent(
|
||||
server,
|
||||
config_file,
|
||||
CLAUDE_SONNET_CONFIG,
|
||||
agent_uuid,
|
||||
tool_ids=tool_ids,
|
||||
tool_rules=tool_rules,
|
||||
@@ -618,7 +621,7 @@ async def test_continue_tool_rule(server, default_user):
|
||||
# ]
|
||||
# tools = [flip_coin_tool, reveal_secret]
|
||||
#
|
||||
# config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
# config_file = CLAUDE_SONNET_CONFIG
|
||||
# agent_state = await setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
# response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
||||
#
|
||||
@@ -824,7 +827,6 @@ async def test_continue_tool_rule(server, default_user):
|
||||
async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
|
||||
"""Test that agent is forced to call a single required-before-exit tool before ending."""
|
||||
agent_name = "required_exit_single_tool_agent"
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
# Set up tools and rules
|
||||
tools = [save_data_tool]
|
||||
@@ -835,7 +837,7 @@ async def test_single_required_before_exit_tool(server, disable_e2b_api_key, sav
|
||||
]
|
||||
|
||||
# Create agent
|
||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Send message that would normally cause exit
|
||||
response = await run_agent_step(
|
||||
@@ -866,7 +868,6 @@ async def test_single_required_before_exit_tool(server, disable_e2b_api_key, sav
|
||||
async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
|
||||
"""Test that agent calls all required-before-exit tools before ending."""
|
||||
agent_name = "required_exit_multi_tool_agent"
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
# Set up tools and rules
|
||||
tools = [save_data_tool, cleanup_temp_files_tool]
|
||||
@@ -878,7 +879,7 @@ async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key,
|
||||
]
|
||||
|
||||
# Create agent
|
||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Send message that would normally cause exit
|
||||
response = await run_agent_step(
|
||||
@@ -911,7 +912,6 @@ async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key,
|
||||
async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
|
||||
"""Test required-before-exit rules work alongside other tool rules."""
|
||||
agent_name = "required_exit_with_rules_agent"
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
# Set up tools and rules - combine with child tool rules
|
||||
tools = [first_secret_tool, save_data_tool]
|
||||
@@ -923,7 +923,7 @@ async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key
|
||||
]
|
||||
|
||||
# Create agent
|
||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Send message that would trigger tool flow
|
||||
response = await run_agent_step(
|
||||
@@ -956,7 +956,6 @@ async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key
|
||||
async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
|
||||
"""Test that agent can exit normally when required tools are called during regular operation."""
|
||||
agent_name = "required_exit_normal_flow_agent"
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
# Set up tools and rules
|
||||
tools = [save_data_tool]
|
||||
@@ -967,7 +966,7 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
|
||||
]
|
||||
|
||||
# Create agent
|
||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Send message that explicitly mentions calling the required tool
|
||||
response = await run_agent_step(
|
||||
@@ -990,51 +989,3 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
|
||||
assert len(send_message_calls) == 1, "Should call send_message exactly once"
|
||||
|
||||
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user):
|
||||
"""Test that when there's a terminal tool rule on send_message, the tool call has request_heartbeat=False."""
|
||||
agent_name = "terminal_send_message_heartbeat_test"
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
# Set up tool rules with terminal rule on send_message
|
||||
tool_rules = [
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
]
|
||||
|
||||
# Create agent
|
||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[], tool_rules=tool_rules)
|
||||
|
||||
# Send message that should trigger send_message tool call
|
||||
response = await run_agent_step(
|
||||
agent_state=agent_state,
|
||||
input_messages=[MessageCreate(role="user", content="Please send me a simple message.")],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert_sanity_checks(response)
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
# Find the send_message tool call and check request_heartbeat is False
|
||||
send_message_call = None
|
||||
for message in response.messages:
|
||||
if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message":
|
||||
send_message_call = message
|
||||
break
|
||||
|
||||
assert send_message_call is not None, "send_message tool call should be found"
|
||||
|
||||
# Parse the arguments and check request_heartbeat
|
||||
try:
|
||||
arguments = json.loads(send_message_call.tool_call.arguments)
|
||||
assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments"
|
||||
assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule"
|
||||
|
||||
print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message")
|
||||
except json.JSONDecodeError:
|
||||
pytest.fail("Failed to parse tool call arguments as JSON")
|
||||
finally:
|
||||
await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user