feat: Enable dynamic toggling of tool choice in v3 agent loop for OpenAI [LET-4564] (#5042)

* Add subsequent flag

* Finish integrating constrained/unconstrained toggling on v3 agent loop

* Update tests to run on v3

* Run lint
This commit is contained in:
Matthew Zhou
2025-09-30 17:18:47 -07:00
committed by Caren Thomas
parent c465da27e6
commit df5c997da0
11 changed files with 77 additions and 102 deletions

View File

@@ -51,6 +51,10 @@ class LettaAgentV3(LettaAgentV2):
* Support Gemini / OpenAI client
"""
def _initialize_state(self):
super()._initialize_state()
self._require_tool_call = False
@trace_method
async def step(
self,
@@ -279,6 +283,15 @@ class LettaAgentV3(LettaAgentV2):
try:
self.last_function_response = _load_last_function_response(messages)
valid_tools = await self._get_valid_tools()
require_tool_call = self.tool_rules_solver.should_force_tool_call()
if self._require_tool_call != require_tool_call:
if require_tool_call:
self.logger.info("switching to constrained mode (forcing tool call)")
else:
self.logger.info("switching to unconstrained mode (allowing non-tool responses)")
self._require_tool_call = require_tool_call
approval_request, approval_response = _maybe_get_approval_messages(messages)
if approval_request and approval_response:
tool_call = approval_request.tool_calls[0]
@@ -307,6 +320,7 @@ class LettaAgentV3(LettaAgentV2):
llm_config=self.agent_state.llm_config,
tools=valid_tools,
force_tool_call=force_tool_call,
requires_subsequent_tool_call=self._require_tool_call,
)
if dry_run:
yield request_data
@@ -590,8 +604,7 @@ class LettaAgentV3(LettaAgentV2):
elif tool_call is None:
# TODO could just hardcode the line here instead of calling the function...
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
# agent_state=agent_state,
# request_heartbeat=False,
agent_state=agent_state,
tool_call_name=None,
tool_rule_violated=False,
tool_rules_solver=tool_rules_solver,
@@ -705,8 +718,7 @@ class LettaAgentV3(LettaAgentV2):
# 4. Decide whether to keep stepping (focal section simplified)
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
# agent_state=agent_state,
# request_heartbeat=request_heartbeat,
agent_state=agent_state,
tool_call_name=tool_call_name,
tool_rule_violated=tool_rule_violated,
tool_rules_solver=tool_rules_solver,
@@ -753,8 +765,7 @@ class LettaAgentV3(LettaAgentV2):
@trace_method
def _decide_continuation(
self,
# agent_state: AgentState,
# request_heartbeat: bool,
agent_state: AgentState,
tool_call_name: Optional[str],
tool_rule_violated: bool,
tool_rules_solver: ToolRulesSolver,
@@ -771,19 +782,14 @@ class LettaAgentV3(LettaAgentV2):
2c. Called tool + tool rule violation (did not execute)
"""
continue_stepping = True # Default continue
continuation_reason: str | None = None
stop_reason: LettaStopReason | None = None
if tool_call_name is None:
# No tool call? End loop
return False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value)
else:
# If we have a tool call, we continue stepping
return True, None, None
# TODO support tool rules
# I think we can just uncomment the below?
if tool_rule_violated:
continue_stepping = True
continuation_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
@@ -791,8 +797,7 @@ class LettaAgentV3(LettaAgentV2):
tool_rules_solver.register_tool_call(tool_call_name)
if tool_rules_solver.is_terminal_tool(tool_call_name):
if continue_stepping:
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
continue_stepping = False
elif tool_rules_solver.has_children_tools(tool_call_name):
@@ -809,7 +814,7 @@ class LettaAgentV3(LettaAgentV2):
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
else:
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
if not continue_stepping and uncalled:
if uncalled:
continue_stepping = True
continuation_reason = (
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
@@ -817,7 +822,7 @@ class LettaAgentV3(LettaAgentV2):
stop_reason = None  # reset, we're still going
return continue_stepping, continuation_reason, stop_reason
return continue_stepping, continuation_reason, stop_reason
@trace_method
async def _get_valid_tools(self):

View File

@@ -183,6 +183,7 @@ class AnthropicClient(LLMClientBase):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise

View File

@@ -70,8 +70,9 @@ class BedrockClient(AnthropicClient):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# remove disallowed fields
if "tool_choice" in data:
del data["tool_choice"]["disable_parallel_tool_use"]

View File

@@ -337,11 +337,12 @@ class DeepseekClient(OpenAIClient):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
llm_config.put_inner_thoughts_in_kwargs = False
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
def add_functions_to_system_message(system_message: ChatMessage):
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"

View File

@@ -280,6 +280,7 @@ class GoogleVertexClient(LLMClientBase):
llm_config: LLMConfig,
tools: List[dict],
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
"""
Constructs a request object in the expected data format for this client.

View File

@@ -29,8 +29,9 @@ class GroqClient(OpenAIClient):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# Groq validation - these fields are not supported and will cause 400 errors
# https://console.groq.com/docs/openai

View File

@@ -127,6 +127,7 @@ class LLMClientBase:
llm_config: LLMConfig,
tools: List[dict],
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
"""
Constructs a request object in the expected data format for this client.

View File

@@ -206,6 +206,7 @@ class OpenAIClient(LLMClientBase):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
"""
Constructs a request object in the expected data format for the OpenAI Responses API.
@@ -224,14 +225,15 @@ class OpenAIClient(LLMClientBase):
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
model = None
# Default to auto, unless there's a forced tool call coming from above
# Default to auto, unless there's a forced tool call coming from above or requires_subsequent_tool_call is True
tool_choice = None
if tools: # only set tool_choice if tools exist
tool_choice = (
"auto"
if force_tool_call is None
else ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
)
if force_tool_call is not None:
tool_choice = {"type": "function", "name": force_tool_call}
elif requires_subsequent_tool_call:
tool_choice = "required"
else:
tool_choice = "auto"
# Convert the tools from the ChatCompletions style to the Responses style
if tools:
@@ -352,6 +354,7 @@ class OpenAIClient(LLMClientBase):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
"""
Constructs a request object in the expected data format for the OpenAI API.
@@ -364,6 +367,7 @@ class OpenAIClient(LLMClientBase):
llm_config=llm_config,
tools=tools,
force_tool_call=force_tool_call,
requires_subsequent_tool_call=requires_subsequent_tool_call,
)
if agent_type == AgentType.letta_v1_agent:
@@ -407,15 +411,16 @@ class OpenAIClient(LLMClientBase):
# TODO: This vllm checking is very brittle and is a patch at most
tool_choice = None
if tools: # only set tool_choice if tools exist
if self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
if force_tool_call is not None:
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
elif requires_subsequent_tool_call:
tool_choice = "required"
elif self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
tool_choice = "auto"
else:
# only set if tools is non-Null
tool_choice = "required"
if force_tool_call is not None:
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
data = ChatCompletionRequest(
model=model,
messages=fill_image_content_in_messages(openai_message_list, messages),

View File

@@ -29,8 +29,9 @@ class XAIClient(OpenAIClient):
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# Specific bug for the mini models (as of Apr 14, 2025)
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}

View File

@@ -0,0 +1,7 @@
{
"context_window": 1047576,
"model": "gpt-4.1-2025-04-14",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null
}

View File

@@ -5,6 +5,7 @@ import uuid
import pytest
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.message import MessageCreate
@@ -12,7 +13,6 @@ from letta.schemas.run import Run
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
from letta.server.server import SyncServer
from letta.services.run_manager import RunManager
from letta.services.telemetry_manager import NoopTelemetryManager
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
@@ -25,7 +25,9 @@ from tests.utils import create_tool_from_func
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
OPENAI_CONFIG = "tests/configs/llm_model_configs/openai-gpt-4.1.json"
CLAUDE_SONNET_CONFIG = "tests/configs/llm_model_configs/claude-4-sonnet.json"
@pytest.fixture()
@@ -38,6 +40,12 @@ async def server():
return server
@pytest.fixture()
def default_config_file():
"""Provides the default config file path for tests."""
return OPENAI_CONFIG
@pytest.fixture(scope="function")
async def first_secret_tool(server):
def first_secret_word():
@@ -241,7 +249,7 @@ async def default_user(server):
async def run_agent_step(agent_state, input_messages, actor):
"""Helper function to run agent step using LettaAgent directly instead of server.send_messages."""
agent_loop = LettaAgentV2(
agent_loop = LettaAgentV3(
agent_state=agent_state,
actor=actor,
)
@@ -283,7 +291,7 @@ async def test_single_path_agent_tool_call_graph(
]
# Make agent state
agent_state = await setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = await run_agent_step(
agent_state=agent_state,
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
@@ -321,9 +329,8 @@ async def test_single_path_agent_tool_call_graph(
@pytest.mark.parametrize(
"config_file",
[
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
CLAUDE_SONNET_CONFIG,
OPENAI_CONFIG,
],
)
@pytest.mark.parametrize("init_tools_case", ["single", "multiple"])
@@ -383,13 +390,12 @@ async def test_claude_initial_tool_rule_enforced(
TerminalToolRule(tool_name=second_secret_tool.name),
]
tools = [first_secret_tool, second_secret_tool]
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
for i in range(3):
agent_uuid = str(uuid.uuid4())
agent_state = await setup_agent(
server,
anthropic_config_file,
CLAUDE_SONNET_CONFIG,
agent_uuid=agent_uuid,
tool_ids=[t.id for t in tools],
tool_rules=tool_rules,
@@ -426,8 +432,7 @@ async def test_claude_initial_tool_rule_enforced(
@pytest.mark.parametrize(
"config_file",
[
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
OPENAI_CONFIG,
],
)
@pytest.mark.asyncio
@@ -508,13 +513,12 @@ async def test_init_tool_rule_always_fails(
include_base_tools,
):
"""Test behavior when InitToolRule invokes a tool that always fails."""
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_uuid = str(uuid.uuid4())
tool_rule = InitToolRule(tool_name=auto_error_tool.name)
agent_state = await setup_agent(
server,
config_file,
OPENAI_CONFIG,
agent_uuid=agent_uuid,
tool_ids=[auto_error_tool.id],
tool_rules=[tool_rule],
@@ -535,7 +539,6 @@ async def test_init_tool_rule_always_fails(
@pytest.mark.asyncio
async def test_continue_tool_rule(server, default_user):
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_uuid = str(uuid.uuid4())
tools = [
@@ -551,7 +554,7 @@ async def test_continue_tool_rule(server, default_user):
agent_state = await setup_agent(
server,
config_file,
CLAUDE_SONNET_CONFIG,
agent_uuid,
tool_ids=tool_ids,
tool_rules=tool_rules,
@@ -618,7 +621,7 @@ async def test_continue_tool_rule(server, default_user):
# ]
# tools = [flip_coin_tool, reveal_secret]
#
# config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
# config_file = CLAUDE_SONNET_CONFIG
# agent_state = await setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
#
@@ -824,7 +827,6 @@ async def test_continue_tool_rule(server, default_user):
async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
"""Test that agent is forced to call a single required-before-exit tool before ending."""
agent_name = "required_exit_single_tool_agent"
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
# Set up tools and rules
tools = [save_data_tool]
@@ -835,7 +837,7 @@ async def test_single_required_before_exit_tool(server, disable_e2b_api_key, sav
]
# Create agent
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Send message that would normally cause exit
response = await run_agent_step(
@@ -866,7 +868,6 @@ async def test_single_required_before_exit_tool(server, disable_e2b_api_key, sav
async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
"""Test that agent calls all required-before-exit tools before ending."""
agent_name = "required_exit_multi_tool_agent"
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
# Set up tools and rules
tools = [save_data_tool, cleanup_temp_files_tool]
@@ -878,7 +879,7 @@ async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key,
]
# Create agent
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Send message that would normally cause exit
response = await run_agent_step(
@@ -911,7 +912,6 @@ async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key,
async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
"""Test required-before-exit rules work alongside other tool rules."""
agent_name = "required_exit_with_rules_agent"
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
# Set up tools and rules - combine with child tool rules
tools = [first_secret_tool, save_data_tool]
@@ -923,7 +923,7 @@ async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key
]
# Create agent
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Send message that would trigger tool flow
response = await run_agent_step(
@@ -956,7 +956,6 @@ async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key
async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
"""Test that agent can exit normally when required tools are called during regular operation."""
agent_name = "required_exit_normal_flow_agent"
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
# Set up tools and rules
tools = [save_data_tool]
@@ -967,7 +966,7 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
]
# Create agent
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Send message that explicitly mentions calling the required tool
response = await run_agent_step(
@@ -990,51 +989,3 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
assert len(send_message_calls) == 1, "Should call send_message exactly once"
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user):
"""Test that when there's a terminal tool rule on send_message, the tool call has request_heartbeat=False."""
agent_name = "terminal_send_message_heartbeat_test"
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
# Set up tool rules with terminal rule on send_message
tool_rules = [
TerminalToolRule(tool_name="send_message"),
]
# Create agent
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[], tool_rules=tool_rules)
# Send message that should trigger send_message tool call
response = await run_agent_step(
agent_state=agent_state,
input_messages=[MessageCreate(role="user", content="Please send me a simple message.")],
actor=default_user,
)
# Assertions
assert_sanity_checks(response)
assert_invoked_function_call(response.messages, "send_message")
# Find the send_message tool call and check request_heartbeat is False
send_message_call = None
for message in response.messages:
if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message":
send_message_call = message
break
assert send_message_call is not None, "send_message tool call should be found"
# Parse the arguments and check request_heartbeat
try:
arguments = json.loads(send_message_call.tool_call.arguments)
assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments"
assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule"
print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message")
except json.JSONDecodeError:
pytest.fail("Failed to parse tool call arguments as JSON")
finally:
await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)