feat: Enable dynamic toggling of tool choice in v3 agent loop for OpenAI [LET-4564] (#5042)
* Add subsequent flag * Finish integrating constrained/unconstrained toggling on v3 agent loop * Update tests to run on v3 * Run lint
This commit is contained in:
committed by
Caren Thomas
parent
c465da27e6
commit
df5c997da0
@@ -51,6 +51,10 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
* Support Gemini / OpenAI client
|
* Support Gemini / OpenAI client
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _initialize_state(self):
|
||||||
|
super()._initialize_state()
|
||||||
|
self._require_tool_call = False
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
async def step(
|
async def step(
|
||||||
self,
|
self,
|
||||||
@@ -279,6 +283,15 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
try:
|
try:
|
||||||
self.last_function_response = _load_last_function_response(messages)
|
self.last_function_response = _load_last_function_response(messages)
|
||||||
valid_tools = await self._get_valid_tools()
|
valid_tools = await self._get_valid_tools()
|
||||||
|
require_tool_call = self.tool_rules_solver.should_force_tool_call()
|
||||||
|
|
||||||
|
if self._require_tool_call != require_tool_call:
|
||||||
|
if require_tool_call:
|
||||||
|
self.logger.info("switching to constrained mode (forcing tool call)")
|
||||||
|
else:
|
||||||
|
self.logger.info("switching to unconstrained mode (allowing non-tool responses)")
|
||||||
|
self._require_tool_call = require_tool_call
|
||||||
|
|
||||||
approval_request, approval_response = _maybe_get_approval_messages(messages)
|
approval_request, approval_response = _maybe_get_approval_messages(messages)
|
||||||
if approval_request and approval_response:
|
if approval_request and approval_response:
|
||||||
tool_call = approval_request.tool_calls[0]
|
tool_call = approval_request.tool_calls[0]
|
||||||
@@ -307,6 +320,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
llm_config=self.agent_state.llm_config,
|
llm_config=self.agent_state.llm_config,
|
||||||
tools=valid_tools,
|
tools=valid_tools,
|
||||||
force_tool_call=force_tool_call,
|
force_tool_call=force_tool_call,
|
||||||
|
requires_subsequent_tool_call=self._require_tool_call,
|
||||||
)
|
)
|
||||||
if dry_run:
|
if dry_run:
|
||||||
yield request_data
|
yield request_data
|
||||||
@@ -590,8 +604,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
elif tool_call is None:
|
elif tool_call is None:
|
||||||
# TODO could just hardcode the line here instead of calling the function...
|
# TODO could just hardcode the line here instead of calling the function...
|
||||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||||
# agent_state=agent_state,
|
agent_state=agent_state,
|
||||||
# request_heartbeat=False,
|
|
||||||
tool_call_name=None,
|
tool_call_name=None,
|
||||||
tool_rule_violated=False,
|
tool_rule_violated=False,
|
||||||
tool_rules_solver=tool_rules_solver,
|
tool_rules_solver=tool_rules_solver,
|
||||||
@@ -705,8 +718,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
|
|
||||||
# 4. Decide whether to keep stepping (focal section simplified)
|
# 4. Decide whether to keep stepping (focal section simplified)
|
||||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||||
# agent_state=agent_state,
|
agent_state=agent_state,
|
||||||
# request_heartbeat=request_heartbeat,
|
|
||||||
tool_call_name=tool_call_name,
|
tool_call_name=tool_call_name,
|
||||||
tool_rule_violated=tool_rule_violated,
|
tool_rule_violated=tool_rule_violated,
|
||||||
tool_rules_solver=tool_rules_solver,
|
tool_rules_solver=tool_rules_solver,
|
||||||
@@ -753,8 +765,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
@trace_method
|
@trace_method
|
||||||
def _decide_continuation(
|
def _decide_continuation(
|
||||||
self,
|
self,
|
||||||
# agent_state: AgentState,
|
agent_state: AgentState,
|
||||||
# request_heartbeat: bool,
|
|
||||||
tool_call_name: Optional[str],
|
tool_call_name: Optional[str],
|
||||||
tool_rule_violated: bool,
|
tool_rule_violated: bool,
|
||||||
tool_rules_solver: ToolRulesSolver,
|
tool_rules_solver: ToolRulesSolver,
|
||||||
@@ -771,19 +782,14 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
2c. Called tool + tool rule violation (did not execute)
|
2c. Called tool + tool rule violation (did not execute)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
continue_stepping = True # Default continue
|
||||||
continuation_reason: str | None = None
|
continuation_reason: str | None = None
|
||||||
stop_reason: LettaStopReason | None = None
|
stop_reason: LettaStopReason | None = None
|
||||||
|
|
||||||
if tool_call_name is None:
|
if tool_call_name is None:
|
||||||
# No tool call? End loop
|
# No tool call? End loop
|
||||||
return False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
return False, None, LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# If we have a tool call, we continue stepping
|
|
||||||
return True, None, None
|
|
||||||
|
|
||||||
# TODO support tool rules
|
|
||||||
# I think we can just uncomment the bellow?
|
|
||||||
if tool_rule_violated:
|
if tool_rule_violated:
|
||||||
continue_stepping = True
|
continue_stepping = True
|
||||||
continuation_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
|
continuation_reason = f"{NON_USER_MSG_PREFIX}Continuing: tool rule violation."
|
||||||
@@ -791,8 +797,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
tool_rules_solver.register_tool_call(tool_call_name)
|
tool_rules_solver.register_tool_call(tool_call_name)
|
||||||
|
|
||||||
if tool_rules_solver.is_terminal_tool(tool_call_name):
|
if tool_rules_solver.is_terminal_tool(tool_call_name):
|
||||||
if continue_stepping:
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
|
||||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
|
|
||||||
continue_stepping = False
|
continue_stepping = False
|
||||||
|
|
||||||
elif tool_rules_solver.has_children_tools(tool_call_name):
|
elif tool_rules_solver.has_children_tools(tool_call_name):
|
||||||
@@ -809,7 +814,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||||
else:
|
else:
|
||||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||||
if not continue_stepping and uncalled:
|
if uncalled:
|
||||||
continue_stepping = True
|
continue_stepping = True
|
||||||
continuation_reason = (
|
continuation_reason = (
|
||||||
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
|
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
|
||||||
@@ -817,7 +822,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
|
|
||||||
stop_reason = None # reset – we’re still going
|
stop_reason = None # reset – we’re still going
|
||||||
|
|
||||||
return continue_stepping, continuation_reason, stop_reason
|
return continue_stepping, continuation_reason, stop_reason
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
async def _get_valid_tools(self):
|
async def _get_valid_tools(self):
|
||||||
|
|||||||
@@ -183,6 +183,7 @@ class AnthropicClient(LLMClientBase):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||||
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
|
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
|
||||||
|
|||||||
@@ -70,8 +70,9 @@ class BedrockClient(AnthropicClient):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||||
# remove disallowed fields
|
# remove disallowed fields
|
||||||
if "tool_choice" in data:
|
if "tool_choice" in data:
|
||||||
del data["tool_choice"]["disable_parallel_tool_use"]
|
del data["tool_choice"]["disable_parallel_tool_use"]
|
||||||
|
|||||||
@@ -337,11 +337,12 @@ class DeepseekClient(OpenAIClient):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
|
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
|
||||||
llm_config.put_inner_thoughts_in_kwargs = False
|
llm_config.put_inner_thoughts_in_kwargs = False
|
||||||
|
|
||||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||||
|
|
||||||
def add_functions_to_system_message(system_message: ChatMessage):
|
def add_functions_to_system_message(system_message: ChatMessage):
|
||||||
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"
|
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"
|
||||||
|
|||||||
@@ -280,6 +280,7 @@ class GoogleVertexClient(LLMClientBase):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: List[dict],
|
tools: List[dict],
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Constructs a request object in the expected data format for this client.
|
Constructs a request object in the expected data format for this client.
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ class GroqClient(OpenAIClient):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||||
|
|
||||||
# Groq validation - these fields are not supported and will cause 400 errors
|
# Groq validation - these fields are not supported and will cause 400 errors
|
||||||
# https://console.groq.com/docs/openai
|
# https://console.groq.com/docs/openai
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ class LLMClientBase:
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: List[dict],
|
tools: List[dict],
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Constructs a request object in the expected data format for this client.
|
Constructs a request object in the expected data format for this client.
|
||||||
|
|||||||
@@ -206,6 +206,7 @@ class OpenAIClient(LLMClientBase):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Constructs a request object in the expected data format for the OpenAI Responses API.
|
Constructs a request object in the expected data format for the OpenAI Responses API.
|
||||||
@@ -224,14 +225,15 @@ class OpenAIClient(LLMClientBase):
|
|||||||
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
|
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
|
||||||
model = None
|
model = None
|
||||||
|
|
||||||
# Default to auto, unless there's a forced tool call coming from above
|
# Default to auto, unless there's a forced tool call coming from above or requires_subsequent_tool_call is True
|
||||||
tool_choice = None
|
tool_choice = None
|
||||||
if tools: # only set tool_choice if tools exist
|
if tools: # only set tool_choice if tools exist
|
||||||
tool_choice = (
|
if force_tool_call is not None:
|
||||||
"auto"
|
tool_choice = {"type": "function", "name": force_tool_call}
|
||||||
if force_tool_call is None
|
elif requires_subsequent_tool_call:
|
||||||
else ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
tool_choice = "required"
|
||||||
)
|
else:
|
||||||
|
tool_choice = "auto"
|
||||||
|
|
||||||
# Convert the tools from the ChatCompletions style to the Responses style
|
# Convert the tools from the ChatCompletions style to the Responses style
|
||||||
if tools:
|
if tools:
|
||||||
@@ -352,6 +354,7 @@ class OpenAIClient(LLMClientBase):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Constructs a request object in the expected data format for the OpenAI API.
|
Constructs a request object in the expected data format for the OpenAI API.
|
||||||
@@ -364,6 +367,7 @@ class OpenAIClient(LLMClientBase):
|
|||||||
llm_config=llm_config,
|
llm_config=llm_config,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
force_tool_call=force_tool_call,
|
force_tool_call=force_tool_call,
|
||||||
|
requires_subsequent_tool_call=requires_subsequent_tool_call,
|
||||||
)
|
)
|
||||||
|
|
||||||
if agent_type == AgentType.letta_v1_agent:
|
if agent_type == AgentType.letta_v1_agent:
|
||||||
@@ -407,15 +411,16 @@ class OpenAIClient(LLMClientBase):
|
|||||||
# TODO: This vllm checking is very brittle and is a patch at most
|
# TODO: This vllm checking is very brittle and is a patch at most
|
||||||
tool_choice = None
|
tool_choice = None
|
||||||
if tools: # only set tool_choice if tools exist
|
if tools: # only set tool_choice if tools exist
|
||||||
if self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
if force_tool_call is not None:
|
||||||
|
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||||
|
elif requires_subsequent_tool_call:
|
||||||
|
tool_choice = "required"
|
||||||
|
elif self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
||||||
tool_choice = "auto"
|
tool_choice = "auto"
|
||||||
else:
|
else:
|
||||||
# only set if tools is non-Null
|
# only set if tools is non-Null
|
||||||
tool_choice = "required"
|
tool_choice = "required"
|
||||||
|
|
||||||
if force_tool_call is not None:
|
|
||||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
|
||||||
|
|
||||||
data = ChatCompletionRequest(
|
data = ChatCompletionRequest(
|
||||||
model=model,
|
model=model,
|
||||||
messages=fill_image_content_in_messages(openai_message_list, messages),
|
messages=fill_image_content_in_messages(openai_message_list, messages),
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ class XAIClient(OpenAIClient):
|
|||||||
llm_config: LLMConfig,
|
llm_config: LLMConfig,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
force_tool_call: Optional[str] = None,
|
force_tool_call: Optional[str] = None,
|
||||||
|
requires_subsequent_tool_call: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||||
|
|
||||||
# Specific bug for the mini models (as of Apr 14, 2025)
|
# Specific bug for the mini models (as of Apr 14, 2025)
|
||||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
||||||
|
|||||||
7
tests/configs/llm_model_configs/openai-gpt-4.1.json
Normal file
7
tests/configs/llm_model_configs/openai-gpt-4.1.json
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"context_window": 1047576,
|
||||||
|
"model": "gpt-4.1-2025-04-14",
|
||||||
|
"model_endpoint_type": "openai",
|
||||||
|
"model_endpoint": "https://api.openai.com/v1",
|
||||||
|
"model_wrapper": null
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import uuid
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||||
|
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||||
from letta.config import LettaConfig
|
from letta.config import LettaConfig
|
||||||
from letta.schemas.letta_message import ToolCallMessage
|
from letta.schemas.letta_message import ToolCallMessage
|
||||||
from letta.schemas.message import MessageCreate
|
from letta.schemas.message import MessageCreate
|
||||||
@@ -12,7 +13,6 @@ from letta.schemas.run import Run
|
|||||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
|
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
from letta.services.run_manager import RunManager
|
from letta.services.run_manager import RunManager
|
||||||
from letta.services.telemetry_manager import NoopTelemetryManager
|
|
||||||
from tests.helpers.endpoints_helper import (
|
from tests.helpers.endpoints_helper import (
|
||||||
assert_invoked_function_call,
|
assert_invoked_function_call,
|
||||||
assert_invoked_send_message_with_keyword,
|
assert_invoked_send_message_with_keyword,
|
||||||
@@ -25,7 +25,9 @@ from tests.utils import create_tool_from_func
|
|||||||
# Generate uuid for agent name for this example
|
# Generate uuid for agent name for this example
|
||||||
namespace = uuid.NAMESPACE_DNS
|
namespace = uuid.NAMESPACE_DNS
|
||||||
agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
|
agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
|
||||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
|
||||||
|
OPENAI_CONFIG = "tests/configs/llm_model_configs/openai-gpt-4.1.json"
|
||||||
|
CLAUDE_SONNET_CONFIG = "tests/configs/llm_model_configs/claude-4-sonnet.json"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@@ -38,6 +40,12 @@ async def server():
|
|||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def default_config_file():
|
||||||
|
"""Provides the default config file path for tests."""
|
||||||
|
return OPENAI_CONFIG
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def first_secret_tool(server):
|
async def first_secret_tool(server):
|
||||||
def first_secret_word():
|
def first_secret_word():
|
||||||
@@ -241,7 +249,7 @@ async def default_user(server):
|
|||||||
|
|
||||||
async def run_agent_step(agent_state, input_messages, actor):
|
async def run_agent_step(agent_state, input_messages, actor):
|
||||||
"""Helper function to run agent step using LettaAgent directly instead of server.send_messages."""
|
"""Helper function to run agent step using LettaAgent directly instead of server.send_messages."""
|
||||||
agent_loop = LettaAgentV2(
|
agent_loop = LettaAgentV3(
|
||||||
agent_state=agent_state,
|
agent_state=agent_state,
|
||||||
actor=actor,
|
actor=actor,
|
||||||
)
|
)
|
||||||
@@ -283,7 +291,7 @@ async def test_single_path_agent_tool_call_graph(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Make agent state
|
# Make agent state
|
||||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||||
response = await run_agent_step(
|
response = await run_agent_step(
|
||||||
agent_state=agent_state,
|
agent_state=agent_state,
|
||||||
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
|
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
|
||||||
@@ -321,9 +329,8 @@ async def test_single_path_agent_tool_call_graph(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"config_file",
|
"config_file",
|
||||||
[
|
[
|
||||||
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
|
CLAUDE_SONNET_CONFIG,
|
||||||
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
|
OPENAI_CONFIG,
|
||||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("init_tools_case", ["single", "multiple"])
|
@pytest.mark.parametrize("init_tools_case", ["single", "multiple"])
|
||||||
@@ -383,13 +390,12 @@ async def test_claude_initial_tool_rule_enforced(
|
|||||||
TerminalToolRule(tool_name=second_secret_tool.name),
|
TerminalToolRule(tool_name=second_secret_tool.name),
|
||||||
]
|
]
|
||||||
tools = [first_secret_tool, second_secret_tool]
|
tools = [first_secret_tool, second_secret_tool]
|
||||||
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
agent_uuid = str(uuid.uuid4())
|
agent_uuid = str(uuid.uuid4())
|
||||||
agent_state = await setup_agent(
|
agent_state = await setup_agent(
|
||||||
server,
|
server,
|
||||||
anthropic_config_file,
|
CLAUDE_SONNET_CONFIG,
|
||||||
agent_uuid=agent_uuid,
|
agent_uuid=agent_uuid,
|
||||||
tool_ids=[t.id for t in tools],
|
tool_ids=[t.id for t in tools],
|
||||||
tool_rules=tool_rules,
|
tool_rules=tool_rules,
|
||||||
@@ -426,8 +432,7 @@ async def test_claude_initial_tool_rule_enforced(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"config_file",
|
"config_file",
|
||||||
[
|
[
|
||||||
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
|
OPENAI_CONFIG,
|
||||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -508,13 +513,12 @@ async def test_init_tool_rule_always_fails(
|
|||||||
include_base_tools,
|
include_base_tools,
|
||||||
):
|
):
|
||||||
"""Test behavior when InitToolRule invokes a tool that always fails."""
|
"""Test behavior when InitToolRule invokes a tool that always fails."""
|
||||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
|
||||||
agent_uuid = str(uuid.uuid4())
|
agent_uuid = str(uuid.uuid4())
|
||||||
|
|
||||||
tool_rule = InitToolRule(tool_name=auto_error_tool.name)
|
tool_rule = InitToolRule(tool_name=auto_error_tool.name)
|
||||||
agent_state = await setup_agent(
|
agent_state = await setup_agent(
|
||||||
server,
|
server,
|
||||||
config_file,
|
OPENAI_CONFIG,
|
||||||
agent_uuid=agent_uuid,
|
agent_uuid=agent_uuid,
|
||||||
tool_ids=[auto_error_tool.id],
|
tool_ids=[auto_error_tool.id],
|
||||||
tool_rules=[tool_rule],
|
tool_rules=[tool_rule],
|
||||||
@@ -535,7 +539,6 @@ async def test_init_tool_rule_always_fails(
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_continue_tool_rule(server, default_user):
|
async def test_continue_tool_rule(server, default_user):
|
||||||
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
|
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
|
||||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
|
||||||
agent_uuid = str(uuid.uuid4())
|
agent_uuid = str(uuid.uuid4())
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
@@ -551,7 +554,7 @@ async def test_continue_tool_rule(server, default_user):
|
|||||||
|
|
||||||
agent_state = await setup_agent(
|
agent_state = await setup_agent(
|
||||||
server,
|
server,
|
||||||
config_file,
|
CLAUDE_SONNET_CONFIG,
|
||||||
agent_uuid,
|
agent_uuid,
|
||||||
tool_ids=tool_ids,
|
tool_ids=tool_ids,
|
||||||
tool_rules=tool_rules,
|
tool_rules=tool_rules,
|
||||||
@@ -618,7 +621,7 @@ async def test_continue_tool_rule(server, default_user):
|
|||||||
# ]
|
# ]
|
||||||
# tools = [flip_coin_tool, reveal_secret]
|
# tools = [flip_coin_tool, reveal_secret]
|
||||||
#
|
#
|
||||||
# config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
# config_file = CLAUDE_SONNET_CONFIG
|
||||||
# agent_state = await setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
# agent_state = await setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||||
# response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
# response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
||||||
#
|
#
|
||||||
@@ -824,7 +827,6 @@ async def test_continue_tool_rule(server, default_user):
|
|||||||
async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
|
async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
|
||||||
"""Test that agent is forced to call a single required-before-exit tool before ending."""
|
"""Test that agent is forced to call a single required-before-exit tool before ending."""
|
||||||
agent_name = "required_exit_single_tool_agent"
|
agent_name = "required_exit_single_tool_agent"
|
||||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
|
||||||
|
|
||||||
# Set up tools and rules
|
# Set up tools and rules
|
||||||
tools = [save_data_tool]
|
tools = [save_data_tool]
|
||||||
@@ -835,7 +837,7 @@ async def test_single_required_before_exit_tool(server, disable_e2b_api_key, sav
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||||
|
|
||||||
# Send message that would normally cause exit
|
# Send message that would normally cause exit
|
||||||
response = await run_agent_step(
|
response = await run_agent_step(
|
||||||
@@ -866,7 +868,6 @@ async def test_single_required_before_exit_tool(server, disable_e2b_api_key, sav
|
|||||||
async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
|
async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
|
||||||
"""Test that agent calls all required-before-exit tools before ending."""
|
"""Test that agent calls all required-before-exit tools before ending."""
|
||||||
agent_name = "required_exit_multi_tool_agent"
|
agent_name = "required_exit_multi_tool_agent"
|
||||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
|
||||||
|
|
||||||
# Set up tools and rules
|
# Set up tools and rules
|
||||||
tools = [save_data_tool, cleanup_temp_files_tool]
|
tools = [save_data_tool, cleanup_temp_files_tool]
|
||||||
@@ -878,7 +879,7 @@ async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key,
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||||
|
|
||||||
# Send message that would normally cause exit
|
# Send message that would normally cause exit
|
||||||
response = await run_agent_step(
|
response = await run_agent_step(
|
||||||
@@ -911,7 +912,6 @@ async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key,
|
|||||||
async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
|
async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
|
||||||
"""Test required-before-exit rules work alongside other tool rules."""
|
"""Test required-before-exit rules work alongside other tool rules."""
|
||||||
agent_name = "required_exit_with_rules_agent"
|
agent_name = "required_exit_with_rules_agent"
|
||||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
|
||||||
|
|
||||||
# Set up tools and rules - combine with child tool rules
|
# Set up tools and rules - combine with child tool rules
|
||||||
tools = [first_secret_tool, save_data_tool]
|
tools = [first_secret_tool, save_data_tool]
|
||||||
@@ -923,7 +923,7 @@ async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||||
|
|
||||||
# Send message that would trigger tool flow
|
# Send message that would trigger tool flow
|
||||||
response = await run_agent_step(
|
response = await run_agent_step(
|
||||||
@@ -956,7 +956,6 @@ async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key
|
|||||||
async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
|
async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
|
||||||
"""Test that agent can exit normally when required tools are called during regular operation."""
|
"""Test that agent can exit normally when required tools are called during regular operation."""
|
||||||
agent_name = "required_exit_normal_flow_agent"
|
agent_name = "required_exit_normal_flow_agent"
|
||||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
|
||||||
|
|
||||||
# Set up tools and rules
|
# Set up tools and rules
|
||||||
tools = [save_data_tool]
|
tools = [save_data_tool]
|
||||||
@@ -967,7 +966,7 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Create agent
|
# Create agent
|
||||||
agent_state = await setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||||
|
|
||||||
# Send message that explicitly mentions calling the required tool
|
# Send message that explicitly mentions calling the required tool
|
||||||
response = await run_agent_step(
|
response = await run_agent_step(
|
||||||
@@ -990,51 +989,3 @@ async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_
|
|||||||
assert len(send_message_calls) == 1, "Should call send_message exactly once"
|
assert len(send_message_calls) == 1, "Should call send_message exactly once"
|
||||||
|
|
||||||
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
|
print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_terminal_tool_rule_send_message_request_heartbeat_false(server, disable_e2b_api_key, default_user):
    """Test that a terminal tool rule on send_message forces request_heartbeat=False.

    Creates an agent whose only tool rule is TerminalToolRule(send_message),
    runs one step, and verifies the resulting send_message tool call carries
    request_heartbeat=False so the agent loop terminates after the call.
    """
    agent_name = "terminal_send_message_heartbeat_test"

    # Set up tool rules with a terminal rule on send_message.
    tool_rules = [
        TerminalToolRule(tool_name="send_message"),
    ]

    # Create agent — use the shared OPENAI_CONFIG, consistent with the other
    # tests in this module (which were migrated off the hard-coded config path).
    agent_state = await setup_agent(server, OPENAI_CONFIG, agent_uuid=agent_name, tool_ids=[], tool_rules=tool_rules)

    # The try/finally wraps the WHOLE post-creation body so the agent is
    # cleaned up even when an assertion fails mid-test (previously only the
    # JSON-parsing section was covered, leaking the agent on early failures).
    try:
        # Send a message that should trigger a send_message tool call.
        response = await run_agent_step(
            agent_state=agent_state,
            input_messages=[MessageCreate(role="user", content="Please send me a simple message.")],
            actor=default_user,
        )

        # Basic response sanity plus presence of the expected tool call.
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "send_message")

        # Locate the send_message tool call in the response stream.
        send_message_call = next(
            (m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message"),
            None,
        )
        assert send_message_call is not None, "send_message tool call should be found"

        # Parse the tool-call arguments; a malformed payload is its own failure mode.
        try:
            arguments = json.loads(send_message_call.tool_call.arguments)
        except json.JSONDecodeError:
            pytest.fail("Failed to parse tool call arguments as JSON")

        # The terminal rule must pin request_heartbeat to False.
        assert "request_heartbeat" in arguments, "request_heartbeat should be present in send_message arguments"
        assert arguments["request_heartbeat"] is False, "request_heartbeat should be False for terminal tool rule"
        print(f"✓ Agent '{agent_name}' correctly set request_heartbeat=False for terminal send_message")
    finally:
        # Always remove the test agent, regardless of how the test body exited.
        await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
|||||||
Reference in New Issue
Block a user