From 54562d88d747739c3bdaffb118f17a3c31511c7e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 23 Jun 2025 17:02:40 -0700 Subject: [PATCH] feat: Add required before exit tool rule (#2977) --- letta/agents/letta_agent.py | 28 +- letta/helpers/converters.py | 3 + letta/helpers/tool_rule_solver.py | 33 +- letta/schemas/enums.py | 1 + letta/schemas/tool_rule.py | 30 +- letta/server/rest_api/utils.py | 21 +- letta/server/server.py | 1 + letta/system.py | 2 +- .../llm_model_configs/openai-gpt-4o.json | 10 +- tests/integration_test_agent_tool_graph.py | 345 ++++++++++++++---- tests/test_sources.py | 1 + tests/test_tool_rule_solver.py | 102 +++++- 12 files changed, 495 insertions(+), 82 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 4fba688d..9b1bac28 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -11,7 +11,7 @@ from opentelemetry.trace import Span from letta.agents.base_agent import BaseAgent from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id -from letta.constants import DEFAULT_MAX_STEPS +from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX from letta.errors import ContextWindowExceededError from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms @@ -56,8 +56,6 @@ from letta.system import package_function_response from letta.types import JsonDict from letta.utils import log_telemetry, validate_function_response -logger = get_logger(__name__) - class LettaAgent(BaseAgent): @@ -98,6 +96,7 @@ class LettaAgent(BaseAgent): self.summarization_agent = None self.summary_block_label = summary_block_label self.max_summarization_retries = max_summarization_retries + self.logger = get_logger(agent_id) # TODO: Expand to more if enable_summarization and 
model_settings.openai_api_key: @@ -223,7 +222,7 @@ class LettaAgent(BaseAgent): elif response.choices[0].message.content: reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons else: - logger.info("No reasoning content found.") + self.logger.info("No reasoning content found.") reasoning = None persisted_messages, should_continue, stop_reason = await self._handle_ai_response( @@ -376,7 +375,7 @@ class LettaAgent(BaseAgent): elif response.choices[0].message.omitted_reasoning_content: reasoning = [OmittedReasoningContent()] else: - logger.info("No reasoning content found.") + self.logger.info("No reasoning content found.") reasoning = None persisted_messages, should_continue, stop_reason = await self._handle_ai_response( @@ -451,7 +450,7 @@ class LettaAgent(BaseAgent): actor=self.actor, ) except Exception as e: - logger.error(f"Failed to update agent's last run metrics: {e}") + self.logger.error(f"Failed to update agent's last run metrics: {e}") @trace_method async def step_stream( @@ -950,7 +949,7 @@ class LettaAgent(BaseAgent): request_heartbeat = tool_args.pop("request_heartbeat", False) if is_final_step: stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) - logger.info("Agent has reached max steps.") + self.logger.info("Agent has reached max steps.") request_heartbeat = False else: # Pre-emptively pop out inner_thoughts @@ -1032,6 +1031,20 @@ class LettaAgent(BaseAgent): elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name): continue_stepping = True + # Check if required-before-exit tools have been called before allowing exit + heartbeat_reason = None # Default + uncalled_required_tools = tool_rules_solver.get_uncalled_required_tools() + if not continue_stepping and uncalled_required_tools: + continue_stepping = True + heartbeat_reason = ( + f"{NON_USER_MSG_PREFIX}Cannot finish, still need to call the following required tools: {', '.join(uncalled_required_tools)}" + ) 
+ + # TODO: @caren is this right? + # reset stop reason since we ain't stopping! + stop_reason = None + self.logger.info(f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: {uncalled_required_tools}") + # 5a. Persist Steps to DB # Following agent loop to persist this before messages # TODO (cliandy): determine what should match old loop w/provider_id @@ -1062,6 +1075,7 @@ class LettaAgent(BaseAgent): function_response=function_response_string, actor=self.actor, add_heartbeat_request_system_message=continue_stepping, + heartbeat_reason=heartbeat_reason, reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 9a8bceda..53ee1565 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -39,6 +39,7 @@ from letta.schemas.tool_rule import ( InitToolRule, MaxCountPerStepToolRule, ParentToolRule, + RequiredBeforeExitToolRule, TerminalToolRule, ToolRule, ) @@ -131,6 +132,8 @@ def deserialize_tool_rule( return MaxCountPerStepToolRule(**data) elif rule_type == ToolRuleType.parent_last_tool: return ParentToolRule(**data) + elif rule_type == ToolRuleType.required_before_exit: + return RequiredBeforeExitToolRule(**data) raise ValueError(f"Unknown ToolRule type: {rule_type}") diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index acd35c8a..ec32c113 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -12,6 +12,7 @@ from letta.schemas.tool_rule import ( InitToolRule, MaxCountPerStepToolRule, ParentToolRule, + RequiredBeforeExitToolRule, TerminalToolRule, ) @@ -41,6 +42,9 @@ class ToolRulesSolver(BaseModel): terminal_tool_rules: List[TerminalToolRule] = Field( default_factory=list, description="Terminal tool rules that end the agent 
loop if called." ) + required_before_exit_tool_rules: List[RequiredBeforeExitToolRule] = Field( + default_factory=list, description="Tool rules that must be called before the agent can exit." + ) tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") def __init__( @@ -51,6 +55,7 @@ class ToolRulesSolver(BaseModel): child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None, parent_tool_rules: Optional[List[ParentToolRule]] = None, terminal_tool_rules: Optional[List[TerminalToolRule]] = None, + required_before_exit_tool_rules: Optional[List[RequiredBeforeExitToolRule]] = None, tool_call_history: Optional[List[str]] = None, **kwargs, ): @@ -60,6 +65,7 @@ class ToolRulesSolver(BaseModel): child_based_tool_rules=child_based_tool_rules or [], parent_tool_rules=parent_tool_rules or [], terminal_tool_rules=terminal_tool_rules or [], + required_before_exit_tool_rules=required_before_exit_tool_rules or [], tool_call_history=tool_call_history or [], **kwargs, ) @@ -88,6 +94,9 @@ class ToolRulesSolver(BaseModel): elif rule.type == ToolRuleType.parent_last_tool: assert isinstance(rule, ParentToolRule) self.parent_tool_rules.append(rule) + elif rule.type == ToolRuleType.required_before_exit: + assert isinstance(rule, RequiredBeforeExitToolRule) + self.required_before_exit_tool_rules.append(rule) def register_tool_call(self, tool_name: str): """Update the internal state to track tool call history.""" @@ -131,8 +140,10 @@ class ToolRulesSolver(BaseModel): return list(final_allowed_tools) def is_terminal_tool(self, tool_name: str) -> bool: - """Check if the tool is defined as a terminal tool in the terminal tool rules.""" - return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules) + """Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules.""" + return any(rule.tool_name == tool_name 
for rule in self.terminal_tool_rules) or any( + rule.tool_name == tool_name for rule in self.required_before_exit_tool_rules + ) def has_children_tools(self, tool_name): """Check if the tool has children tools""" @@ -142,6 +153,24 @@ class ToolRulesSolver(BaseModel): """Check if the tool is defined as a continue tool in the tool rules.""" return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) + def has_required_tools_been_called(self) -> bool: + """Check if all required-before-exit tools have been called.""" + return len(self.get_uncalled_required_tools()) == 0 + + def get_uncalled_required_tools(self) -> List[str]: + """Get the list of required-before-exit tools that have not been called yet.""" + if not self.required_before_exit_tool_rules: + return [] # No required tools means no uncalled tools + + required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules} + called_tool_names = set(self.tool_call_history) + + return list(required_tool_names - called_tool_names) + + def get_ending_tool_names(self) -> List[str]: + """Get the names of tools that are required before exit.""" + return [rule.tool_name for rule in self.required_before_exit_tool_rules] + def compile_tool_rule_prompts(self) -> Optional[Block]: """ Compile prompt templates from all tool rules into an ephemeral Block. 
diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 627fc3fc..d4c4714e 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -86,6 +86,7 @@ class ToolRuleType(str, Enum): constrain_child_tools = "constrain_child_tools" max_count_per_step = "max_count_per_step" parent_last_tool = "parent_last_tool" + required_before_exit = "required_before_exit" # tool must be called before loop can exit class FileProcessingStatus(str, Enum): diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 94dc4978..fef64323 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -181,6 +181,25 @@ class ContinueToolRule(BaseToolRule): ) +class RequiredBeforeExitToolRule(BaseToolRule): + """ + Represents a tool rule configuration where this tool must be called before the agent loop can exit. + """ + + type: Literal[ToolRuleType.required_before_exit] = ToolRuleType.required_before_exit + prompt_template: Optional[str] = Field( + default="{{ tool_name }} must be called before ending the conversation", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) + + def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: + """Returns all available tools - the logic for preventing exit is handled elsewhere.""" + return available_tools + + def _get_default_template(self) -> Optional[str]: + return "{{ tool_name }} must be called before ending the conversation" + + class MaxCountPerStepToolRule(BaseToolRule): """ Represents a tool rule configuration which constrains the total number of times this tool can be invoked in a single step. 
@@ -208,6 +227,15 @@ class MaxCountPerStepToolRule(BaseToolRule): ToolRule = Annotated[ - Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule], + Union[ + ChildToolRule, + InitToolRule, + TerminalToolRule, + ConditionalToolRule, + ContinueToolRule, + RequiredBeforeExitToolRule, + MaxCountPerStepToolRule, + ParentToolRule, + ], Field(discriminator="type"), ] diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index cefaafa1..271ae3bc 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -194,6 +194,7 @@ def create_letta_messages_from_llm_response( function_response: Optional[str], actor: User, add_heartbeat_request_system_message: bool = False, + heartbeat_reason: Optional[str] = None, reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, llm_batch_item_id: Optional[str] = None, @@ -254,7 +255,12 @@ def create_letta_messages_from_llm_response( if add_heartbeat_request_system_message: heartbeat_system_message = create_heartbeat_system_message( - agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor, llm_batch_item_id=llm_batch_item_id + agent_id=agent_id, + model=model, + function_call_success=function_call_success, + actor=actor, + llm_batch_item_id=llm_batch_item_id, + heartbeat_reason=heartbeat_reason, ) messages.append(heartbeat_system_message) @@ -265,9 +271,18 @@ def create_letta_messages_from_llm_response( def create_heartbeat_system_message( - agent_id: str, model: str, function_call_success: bool, actor: User, llm_batch_item_id: Optional[str] = None + agent_id: str, + model: str, + function_call_success: bool, + actor: User, + llm_batch_item_id: Optional[str] = None, + heartbeat_reason: Optional[str] = None, ) -> Message: - text_content = REQ_HEARTBEAT_MESSAGE 
if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE + if heartbeat_reason: + text_content = heartbeat_reason + else: + text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE + heartbeat_system_message = Message( role=MessageRole.user, content=[TextContent(text=get_heartbeat(text_content))], diff --git a/letta/server/server.py b/letta/server/server.py index 89c39892..b9d50a8c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -710,6 +710,7 @@ class SyncServer(Server): # Run the agent state forward return self._step(actor=actor, agent_id=agent_id, input_messages=message) + # TODO: Deprecate this def send_messages( self, actor: User, diff --git a/letta/system.py b/letta/system.py index 05b83c08..06acb8f9 100644 --- a/letta/system.py +++ b/letta/system.py @@ -87,7 +87,7 @@ def get_initial_boot_messages(version="startup"): return messages -def get_heartbeat(reason="Automated timer", include_location=False, location_name="San Francisco, CA, USA"): +def get_heartbeat(reason: str = "Automated timer", include_location: bool = False, location_name: str = "San Francisco, CA, USA"): # Package the message with time and location formatted_time = get_local_time() packaged_message = { diff --git a/tests/configs/llm_model_configs/openai-gpt-4o.json b/tests/configs/llm_model_configs/openai-gpt-4o.json index 8e2cd44a..85c6b3ac 100644 --- a/tests/configs/llm_model_configs/openai-gpt-4o.json +++ b/tests/configs/llm_model_configs/openai-gpt-4o.json @@ -1,7 +1,7 @@ { - "context_window": 8192, - "model": "gpt-4o", - "model_endpoint_type": "openai", - "model_endpoint": "https://api.openai.com/v1", - "model_wrapper": null + "context_window": 32000, + "model": "gpt-4o", + "model_endpoint_type": "openai", + "model_endpoint": "https://api.openai.com/v1", + "model_wrapper": null } diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 6c413dcf..7698e370 100644 --- 
a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -1,15 +1,15 @@ -import time +import asyncio import uuid import pytest +from letta.agents.letta_agent import LettaAgent from letta.config import LettaConfig from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.letta_response import LettaResponse -from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import MessageCreate -from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule from letta.server.server import SyncServer +from letta.services.telemetry_manager import NoopTelemetryManager from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, @@ -25,6 +25,13 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph")) config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + @pytest.fixture() def server(): config = LettaConfig.load() @@ -181,13 +188,83 @@ def auto_error_tool(server): yield tool +@pytest.fixture(scope="function") +def save_data_tool(server): + def save_data(): + """ + Saves important data before exiting. + + Returns: + str: Confirmation that data was saved. + """ + return "Data saved successfully" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=save_data), actor=actor) + yield tool + + +@pytest.fixture(scope="function") +def cleanup_temp_files_tool(server): + def cleanup_temp_files(): + """ + Cleans up temporary files before exiting. + + Returns: + str: Confirmation that cleanup was completed. 
+ """ + return "Temporary files cleaned up" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=cleanup_temp_files), actor=actor) + yield tool + + +@pytest.fixture(scope="function") +def validate_work_tool(server): + def validate_work(): + """ + Validates that work is complete before exiting. + + Returns: + str: Validation result. + """ + return "Work validation passed" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=validate_work), actor=actor) + yield tool + + @pytest.fixture def default_user(server): yield server.user_manager.get_user_or_default() +async def run_agent_step(server, agent_id, input_messages, actor): + """Helper function to run agent step using LettaAgent directly instead of server.send_messages.""" + agent_loop = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + job_manager=server.job_manager, + passage_manager=server.passage_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=NoopTelemetryManager(), + ) + + return await agent_loop.step( + input_messages, + max_steps=50, + use_assistant_message=False, + ) + + @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_single_path_agent_tool_call_graph( +@pytest.mark.asyncio +async def test_single_path_agent_tool_call_graph( server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user ): cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) @@ -207,18 +284,11 @@ def test_single_path_agent_tool_call_graph( # Make agent state agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - usage_stats = 
server.send_messages( - actor=default_user, + response = await run_agent_step( + server=server, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")], - ) - messages = [message for step_messages in usage_stats.steps_messages for message in step_messages] - letta_messages = [] - for m in messages: - letta_messages += m.to_letta_messages() - - response = LettaResponse( - messages=letta_messages, stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), usage=usage_stats + actor=default_user, ) # Make checks @@ -299,7 +369,8 @@ def test_check_tool_rules_with_different_models_parametrized( @pytest.mark.timeout(180) -def test_claude_initial_tool_rule_enforced( +@pytest.mark.asyncio +async def test_claude_initial_tool_rule_enforced( server, disable_e2b_api_key, first_secret_tool, @@ -325,20 +396,11 @@ def test_claude_initial_tool_rule_enforced( tool_rules=tool_rules, ) - usage_stats = server.send_messages( - actor=default_user, + response = await run_agent_step( + server=server, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="What is the second secret word?")], - ) - messages = [m for step in usage_stats.steps_messages for m in step] - letta_messages = [] - for m in messages: - letta_messages += m.to_letta_messages() - - response = LettaResponse( - messages=letta_messages, - stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), - usage=usage_stats, + actor=default_user, ) assert_sanity_checks(response) @@ -359,7 +421,7 @@ def test_claude_initial_tool_rule_enforced( # Exponential backoff if i < 2: backoff_time = 10 * (2**i) - time.sleep(backoff_time) + await asyncio.sleep(backoff_time) @pytest.mark.timeout(60) @@ -370,7 +432,8 @@ def test_claude_initial_tool_rule_enforced( "tests/configs/llm_model_configs/openai-gpt-4o.json", ], ) -def test_agent_no_structured_output_with_one_child_tool_parametrized( +@pytest.mark.asyncio +async def 
test_agent_no_structured_output_with_one_child_tool_parametrized( server, disable_e2b_api_key, default_user, @@ -404,20 +467,11 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized( tool_rules=tool_rules, ) - usage_stats = server.send_messages( - actor=default_user, + response = await run_agent_step( + server=server, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="hi. run archival memory search")], - ) - messages = [m for step in usage_stats.steps_messages for m in step] - letta_messages = [] - for m in messages: - letta_messages += m.to_letta_messages() - - response = LettaResponse( - messages=letta_messages, - stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), - usage=usage_stats, + actor=default_user, ) # Run assertions @@ -448,7 +502,8 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized( @pytest.mark.timeout(30) @pytest.mark.parametrize("include_base_tools", [False, True]) -def test_init_tool_rule_always_fails( +@pytest.mark.asyncio +async def test_init_tool_rule_always_fails( server, disable_e2b_api_key, auto_error_tool, @@ -469,17 +524,11 @@ def test_init_tool_rule_always_fails( include_base_tools=include_base_tools, ) - usage_stats = server.send_messages( - actor=default_user, + response = await run_agent_step( + server=server, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="blah blah blah")], - ) - messages = [m for step in usage_stats.steps_messages for m in step] - letta_messages = [msg for m in messages for msg in m.to_letta_messages()] - response = LettaResponse( - messages=letta_messages, - stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), - usage=usage_stats, + actor=default_user, ) assert_invoked_function_call(response.messages, auto_error_tool.name) @@ -487,7 +536,8 @@ def test_init_tool_rule_always_fails( cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) -def test_continue_tool_rule(server, 
default_user): +@pytest.mark.asyncio +async def test_continue_tool_rule(server, default_user): """Test the continue tool rule by forcing send_message to loop before ending with core_memory_append.""" config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" agent_uuid = str(uuid.uuid4()) @@ -512,17 +562,11 @@ def test_continue_tool_rule(server, default_user): include_base_tool_rules=False, ) - usage_stats = server.send_messages( - actor=default_user, + response = await run_agent_step( + server=server, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")], - ) - messages = [m for step in usage_stats.steps_messages for m in step] - letta_messages = [msg for m in messages for msg in m.to_letta_messages()] - response = LettaResponse( - messages=letta_messages, - stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), - usage=usage_stats, + actor=default_user, ) assert_invoked_function_call(response.messages, "send_message") @@ -775,3 +819,180 @@ def test_continue_tool_rule(server, default_user): # assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" # # cleanup(client, agent_uuid=agent_state.id) + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user): + """Test that agent is forced to call a single required-before-exit tool before ending.""" + agent_name = "required_exit_single_tool_agent" + config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" + + # Set up tools and rules + tools = [save_data_tool] + tool_rules = [ + InitToolRule(tool_name="send_message"), + RequiredBeforeExitToolRule(tool_name="save_data"), + TerminalToolRule(tool_name="send_message"), + ] + + # Create agent + agent_state = setup_agent(server, config_file, agent_uuid=agent_name, 
tool_ids=[t.id for t in tools], tool_rules=tool_rules) + + # Send message that would normally cause exit + response = await run_agent_step( + server=server, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Please finish your work and send me a message.")], + actor=default_user, + ) + + # Assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "save_data") + assert_invoked_function_call(response.messages, "send_message") + + # The key test is that both tools were called - the agent was forced to call save_data + # even when it tried to exit early with send_message + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"] + send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"] + + assert len(save_data_calls) >= 1, "save_data should be called at least once" + assert len(send_message_calls) >= 1, "send_message should be called at least once" + + print(f"✓ Agent '{agent_name}' successfully called required tool before exit") + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user): + """Test that agent calls all required-before-exit tools before ending.""" + agent_name = "required_exit_multi_tool_agent" + config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" + + # Set up tools and rules + tools = [save_data_tool, cleanup_temp_files_tool] + tool_rules = [ + InitToolRule(tool_name="send_message"), + RequiredBeforeExitToolRule(tool_name="save_data"), + RequiredBeforeExitToolRule(tool_name="cleanup_temp_files"), + TerminalToolRule(tool_name="send_message"), + ] + + # Create agent + agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + + # Send message that 
would normally cause exit + response = await run_agent_step( + server=server, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Complete all necessary tasks and then send me a message.")], + actor=default_user, + ) + + # Assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "save_data") + assert_invoked_function_call(response.messages, "cleanup_temp_files") + assert_invoked_function_call(response.messages, "send_message") + + # Verify that all required tools were eventually called + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"] + cleanup_calls = [tc for tc in tool_calls if tc.tool_call.name == "cleanup_temp_files"] + send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"] + + assert len(save_data_calls) >= 1, "save_data should be called at least once" + assert len(cleanup_calls) >= 1, "cleanup_temp_files should be called at least once" + assert len(send_message_calls) >= 1, "send_message should be called at least once" + + print(f"✓ Agent '{agent_name}' successfully called all required tools before exit") + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user): + """Test required-before-exit rules work alongside other tool rules.""" + agent_name = "required_exit_with_rules_agent" + config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" + + # Set up tools and rules - combine with child tool rules + tools = [first_secret_tool, save_data_tool] + tool_rules = [ + InitToolRule(tool_name="first_secret_word"), + ChildToolRule(tool_name="first_secret_word", children=["send_message"]), + RequiredBeforeExitToolRule(tool_name="save_data"), + TerminalToolRule(tool_name="send_message"), + ] + + # Create agent + agent_state = 
setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + + # Send message that would trigger tool flow + response = await run_agent_step( + server=server, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Get the first secret word and then finish up.")], + actor=default_user, + ) + + # Assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "first_secret_word") + assert_invoked_function_call(response.messages, "save_data") + assert_invoked_function_call(response.messages, "send_message") + + # Verify that all tools were called (first_secret_word due to InitToolRule, save_data due to RequiredBeforeExitToolRule) + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + first_secret_calls = [tc for tc in tool_calls if tc.tool_call.name == "first_secret_word"] + save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"] + send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"] + + assert len(first_secret_calls) >= 1, "first_secret_word should be called due to InitToolRule" + assert len(save_data_calls) >= 1, "save_data should be called due to RequiredBeforeExitToolRule" + assert len(send_message_calls) >= 1, "send_message should be called eventually" + + print(f"✓ Agent '{agent_name}' successfully handled mixed tool rules") + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user): + """Test that agent can exit normally when required tools are called during regular operation.""" + agent_name = "required_exit_normal_flow_agent" + config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" + + # Set up tools and rules + tools = [save_data_tool] + tool_rules = [ + InitToolRule(tool_name="save_data"), + RequiredBeforeExitToolRule(tool_name="send_message"), + 
TerminalToolRule(tool_name="send_message"), + ] + + # Create agent + agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + + # Send message that explicitly mentions calling the required tool + response = await run_agent_step( + server=server, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Please save data and then send me a message when done.")], + actor=default_user, + ) + + # Assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "save_data") + assert_invoked_function_call(response.messages, "send_message") + + # Should not have excessive tool calls - agent should exit cleanly after requirements are met + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"] + send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"] + + assert len(save_data_calls) == 1, "Should call save_data exactly once" + assert len(send_message_calls) == 1, "Should call send_message exactly once" + + print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally") diff --git a/tests/test_sources.py b/tests/test_sources.py index a654ab4d..42a289de 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -61,6 +61,7 @@ def agent_state(client: LettaSDKClient): grep_tool = client.tools.list(name="grep")[0] agent_state = client.agents.create( + name="test_sources_agent", memory_blocks=[ CreateBlock( label="human", diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 39015113..bdff9e2e 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -2,7 +2,14 @@ import pytest from letta.helpers import ToolRulesSolver from letta.helpers.tool_rule_solver import ToolRuleValidationError -from letta.schemas.tool_rule import ChildToolRule, 
def test_required_before_exit_tool_rule_has_required_tools_been_called():
    """A solver with no rules reports the required-tools check as satisfied."""
    empty_solver = ToolRulesSolver(tool_rules=[])

    all_called = empty_solver.has_required_tools_been_called()
    assert all_called is True, "Should return True when no required tools are defined"


def test_required_before_exit_tool_rule_single_required_tool():
    """One required tool: pending until registered, satisfied afterwards."""
    solver = ToolRulesSolver(tool_rules=[RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)])

    # Nothing registered yet, so the requirement is outstanding.
    assert solver.has_required_tools_been_called() is False, "Should return False when required tool hasn't been called"
    assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should return list with uncalled required tool"

    solver.register_tool_call(SAVE_TOOL)

    assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
    assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"


def test_required_before_exit_tool_rule_multiple_required_tools():
    """Two required tools: the check passes only once both are registered."""
    rules = [
        RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_1),
        RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_2),
    ]
    solver = ToolRulesSolver(tool_rules=rules)

    assert solver.has_required_tools_been_called() is False, "Should return False when no required tools have been called"
    pending = solver.get_uncalled_required_tools()
    # Compared as a set: this first check does not depend on ordering.
    assert set(pending) == {REQUIRED_TOOL_1, REQUIRED_TOOL_2}, "Should return both uncalled required tools"

    # Registering only the first tool leaves the second outstanding.
    solver.register_tool_call(REQUIRED_TOOL_1)

    assert solver.has_required_tools_been_called() is False, "Should return False when only one required tool has been called"
    assert solver.get_uncalled_required_tools() == [REQUIRED_TOOL_2], "Should return remaining uncalled required tool"

    # Registering the second tool completes the requirement.
    solver.register_tool_call(REQUIRED_TOOL_2)

    assert solver.has_required_tools_been_called() is True, "Should return True when all required tools have been called"
    assert solver.get_uncalled_required_tools() == [], "Should return empty list when all required tools have been called"


def test_required_before_exit_tool_rule_mixed_with_other_tools():
    """Calls to unrelated tools do not satisfy a required-before-exit rule."""
    solver = ToolRulesSolver(tool_rules=[RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)])

    # Register tools that are not the required one.
    for unrelated_tool in (START_TOOL, HELPER_TOOL):
        solver.register_tool_call(unrelated_tool)

    assert solver.has_required_tools_been_called() is False, "Should return False even after calling other tools"
    assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should still show required tool as uncalled"

    # Now register the required tool itself.
    solver.register_tool_call(SAVE_TOOL)

    assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
    assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"


def test_required_before_exit_tool_rule_is_terminal():
    """Required-before-exit tools are reported as terminal, alongside regular terminal tools."""
    solver = ToolRulesSolver(
        tool_rules=[
            RequiredBeforeExitToolRule(tool_name=SAVE_TOOL),
            TerminalToolRule(tool_name=END_TOOL),
        ]
    )

    # (tool name, expected is_terminal_tool result, failure message)
    expectations = (
        (SAVE_TOOL, True, "Required-before-exit tool should be considered terminal"),
        (END_TOOL, True, "Regular terminal tool should still be considered terminal"),
        (START_TOOL, False, "Non-terminal tool should not be considered terminal"),
    )
    for tool_name, expected, failure_message in expectations:
        assert solver.is_terminal_tool(tool_name) is expected, failure_message


def test_required_before_exit_tool_rule_clear_history():
    """Clearing the tool history makes a previously satisfied requirement pending again."""
    solver = ToolRulesSolver(tool_rules=[RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)])

    # Satisfy the requirement first.
    solver.register_tool_call(SAVE_TOOL)
    assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"

    # Wiping the history should undo that.
    solver.clear_tool_history()

    assert solver.has_required_tools_been_called() is False, "Should return False after clearing history"
    assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"