diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index 4fba688d..9b1bac28 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -11,7 +11,7 @@ from opentelemetry.trace import Span
from letta.agents.base_agent import BaseAgent
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_no_persist_async, generate_step_id
-from letta.constants import DEFAULT_MAX_STEPS
+from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer, get_utc_time, get_utc_timestamp_ns, ns_to_ms
@@ -56,8 +56,6 @@ from letta.system import package_function_response
from letta.types import JsonDict
from letta.utils import log_telemetry, validate_function_response
-logger = get_logger(__name__)
-
class LettaAgent(BaseAgent):
@@ -98,6 +96,7 @@ class LettaAgent(BaseAgent):
self.summarization_agent = None
self.summary_block_label = summary_block_label
self.max_summarization_retries = max_summarization_retries
+ self.logger = get_logger(agent_id)
# TODO: Expand to more
if enable_summarization and model_settings.openai_api_key:
@@ -223,7 +222,7 @@ class LettaAgent(BaseAgent):
elif response.choices[0].message.content:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
else:
- logger.info("No reasoning content found.")
+ self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
@@ -376,7 +375,7 @@ class LettaAgent(BaseAgent):
elif response.choices[0].message.omitted_reasoning_content:
reasoning = [OmittedReasoningContent()]
else:
- logger.info("No reasoning content found.")
+ self.logger.info("No reasoning content found.")
reasoning = None
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
@@ -451,7 +450,7 @@ class LettaAgent(BaseAgent):
actor=self.actor,
)
except Exception as e:
- logger.error(f"Failed to update agent's last run metrics: {e}")
+ self.logger.error(f"Failed to update agent's last run metrics: {e}")
@trace_method
async def step_stream(
@@ -950,7 +949,7 @@ class LettaAgent(BaseAgent):
request_heartbeat = tool_args.pop("request_heartbeat", False)
if is_final_step:
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
- logger.info("Agent has reached max steps.")
+ self.logger.info("Agent has reached max steps.")
request_heartbeat = False
else:
# Pre-emptively pop out inner_thoughts
@@ -1032,6 +1031,20 @@ class LettaAgent(BaseAgent):
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
continue_stepping = True
+ # Check if required-before-exit tools have been called before allowing exit
+            heartbeat_reason = None  # Only set when forcing continuation for uncalled required tools below
+ uncalled_required_tools = tool_rules_solver.get_uncalled_required_tools()
+ if not continue_stepping and uncalled_required_tools:
+ continue_stepping = True
+ heartbeat_reason = (
+ f"{NON_USER_MSG_PREFIX}Cannot finish, still need to call the following required tools: {', '.join(uncalled_required_tools)}"
+ )
+
+                # Clear any previously computed stop reason: the loop is being
+                # forced to continue until all required-before-exit tools have
+                # been called, so the prior stop reason no longer applies.
+ stop_reason = None
+ self.logger.info(f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: {uncalled_required_tools}")
+
# 5a. Persist Steps to DB
# Following agent loop to persist this before messages
# TODO (cliandy): determine what should match old loop w/provider_id
@@ -1062,6 +1075,7 @@ class LettaAgent(BaseAgent):
function_response=function_response_string,
actor=self.actor,
add_heartbeat_request_system_message=continue_stepping,
+ heartbeat_reason=heartbeat_reason,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py
index 9a8bceda..53ee1565 100644
--- a/letta/helpers/converters.py
+++ b/letta/helpers/converters.py
@@ -39,6 +39,7 @@ from letta.schemas.tool_rule import (
InitToolRule,
MaxCountPerStepToolRule,
ParentToolRule,
+ RequiredBeforeExitToolRule,
TerminalToolRule,
ToolRule,
)
@@ -131,6 +132,8 @@ def deserialize_tool_rule(
return MaxCountPerStepToolRule(**data)
elif rule_type == ToolRuleType.parent_last_tool:
return ParentToolRule(**data)
+ elif rule_type == ToolRuleType.required_before_exit:
+ return RequiredBeforeExitToolRule(**data)
raise ValueError(f"Unknown ToolRule type: {rule_type}")
diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py
index acd35c8a..ec32c113 100644
--- a/letta/helpers/tool_rule_solver.py
+++ b/letta/helpers/tool_rule_solver.py
@@ -12,6 +12,7 @@ from letta.schemas.tool_rule import (
InitToolRule,
MaxCountPerStepToolRule,
ParentToolRule,
+ RequiredBeforeExitToolRule,
TerminalToolRule,
)
@@ -41,6 +42,9 @@ class ToolRulesSolver(BaseModel):
terminal_tool_rules: List[TerminalToolRule] = Field(
default_factory=list, description="Terminal tool rules that end the agent loop if called."
)
+ required_before_exit_tool_rules: List[RequiredBeforeExitToolRule] = Field(
+ default_factory=list, description="Tool rules that must be called before the agent can exit."
+ )
tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
def __init__(
@@ -51,6 +55,7 @@ class ToolRulesSolver(BaseModel):
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
parent_tool_rules: Optional[List[ParentToolRule]] = None,
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
+ required_before_exit_tool_rules: Optional[List[RequiredBeforeExitToolRule]] = None,
tool_call_history: Optional[List[str]] = None,
**kwargs,
):
@@ -60,6 +65,7 @@ class ToolRulesSolver(BaseModel):
child_based_tool_rules=child_based_tool_rules or [],
parent_tool_rules=parent_tool_rules or [],
terminal_tool_rules=terminal_tool_rules or [],
+ required_before_exit_tool_rules=required_before_exit_tool_rules or [],
tool_call_history=tool_call_history or [],
**kwargs,
)
@@ -88,6 +94,9 @@ class ToolRulesSolver(BaseModel):
elif rule.type == ToolRuleType.parent_last_tool:
assert isinstance(rule, ParentToolRule)
self.parent_tool_rules.append(rule)
+ elif rule.type == ToolRuleType.required_before_exit:
+ assert isinstance(rule, RequiredBeforeExitToolRule)
+ self.required_before_exit_tool_rules.append(rule)
def register_tool_call(self, tool_name: str):
"""Update the internal state to track tool call history."""
@@ -131,8 +140,10 @@ class ToolRulesSolver(BaseModel):
return list(final_allowed_tools)
def is_terminal_tool(self, tool_name: str) -> bool:
- """Check if the tool is defined as a terminal tool in the terminal tool rules."""
- return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
+ """Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules."""
+ return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules) or any(
+ rule.tool_name == tool_name for rule in self.required_before_exit_tool_rules
+ )
def has_children_tools(self, tool_name):
"""Check if the tool has children tools"""
@@ -142,6 +153,24 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
+ def has_required_tools_been_called(self) -> bool:
+ """Check if all required-before-exit tools have been called."""
+ return len(self.get_uncalled_required_tools()) == 0
+
+ def get_uncalled_required_tools(self) -> List[str]:
+ """Get the list of required-before-exit tools that have not been called yet."""
+ if not self.required_before_exit_tool_rules:
+ return [] # No required tools means no uncalled tools
+
+ required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules}
+ called_tool_names = set(self.tool_call_history)
+
+ return list(required_tool_names - called_tool_names)
+
+ def get_ending_tool_names(self) -> List[str]:
+ """Get the names of tools that are required before exit."""
+ return [rule.tool_name for rule in self.required_before_exit_tool_rules]
+
def compile_tool_rule_prompts(self) -> Optional[Block]:
"""
Compile prompt templates from all tool rules into an ephemeral Block.
diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py
index 627fc3fc..d4c4714e 100644
--- a/letta/schemas/enums.py
+++ b/letta/schemas/enums.py
@@ -86,6 +86,7 @@ class ToolRuleType(str, Enum):
constrain_child_tools = "constrain_child_tools"
max_count_per_step = "max_count_per_step"
parent_last_tool = "parent_last_tool"
+ required_before_exit = "required_before_exit" # tool must be called before loop can exit
class FileProcessingStatus(str, Enum):
diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py
index 94dc4978..fef64323 100644
--- a/letta/schemas/tool_rule.py
+++ b/letta/schemas/tool_rule.py
@@ -181,6 +181,25 @@ class ContinueToolRule(BaseToolRule):
)
+class RequiredBeforeExitToolRule(BaseToolRule):
+ """
+ Represents a tool rule configuration where this tool must be called before the agent loop can exit.
+ """
+
+ type: Literal[ToolRuleType.required_before_exit] = ToolRuleType.required_before_exit
+ prompt_template: Optional[str] = Field(
+ default="{{ tool_name }} must be called before ending the conversation",
+ description="Optional Jinja2 template for generating agent prompt about this tool rule.",
+ )
+
+ def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
+ """Returns all available tools - the logic for preventing exit is handled elsewhere."""
+ return available_tools
+
+ def _get_default_template(self) -> Optional[str]:
+ return "{{ tool_name }} must be called before ending the conversation"
+
+
class MaxCountPerStepToolRule(BaseToolRule):
"""
Represents a tool rule configuration which constrains the total number of times this tool can be invoked in a single step.
@@ -208,6 +227,15 @@ class MaxCountPerStepToolRule(BaseToolRule):
ToolRule = Annotated[
- Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
+ Union[
+ ChildToolRule,
+ InitToolRule,
+ TerminalToolRule,
+ ConditionalToolRule,
+ ContinueToolRule,
+ RequiredBeforeExitToolRule,
+ MaxCountPerStepToolRule,
+ ParentToolRule,
+ ],
Field(discriminator="type"),
]
diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index cefaafa1..271ae3bc 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -194,6 +194,7 @@ def create_letta_messages_from_llm_response(
function_response: Optional[str],
actor: User,
add_heartbeat_request_system_message: bool = False,
+ heartbeat_reason: Optional[str] = None,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
llm_batch_item_id: Optional[str] = None,
@@ -254,7 +255,12 @@ def create_letta_messages_from_llm_response(
if add_heartbeat_request_system_message:
heartbeat_system_message = create_heartbeat_system_message(
- agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor, llm_batch_item_id=llm_batch_item_id
+ agent_id=agent_id,
+ model=model,
+ function_call_success=function_call_success,
+ actor=actor,
+ llm_batch_item_id=llm_batch_item_id,
+ heartbeat_reason=heartbeat_reason,
)
messages.append(heartbeat_system_message)
@@ -265,9 +271,18 @@ def create_letta_messages_from_llm_response(
def create_heartbeat_system_message(
- agent_id: str, model: str, function_call_success: bool, actor: User, llm_batch_item_id: Optional[str] = None
+ agent_id: str,
+ model: str,
+ function_call_success: bool,
+ actor: User,
+ llm_batch_item_id: Optional[str] = None,
+ heartbeat_reason: Optional[str] = None,
) -> Message:
- text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE
+ if heartbeat_reason:
+ text_content = heartbeat_reason
+ else:
+ text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE
+
heartbeat_system_message = Message(
role=MessageRole.user,
content=[TextContent(text=get_heartbeat(text_content))],
diff --git a/letta/server/server.py b/letta/server/server.py
index 89c39892..b9d50a8c 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -710,6 +710,7 @@ class SyncServer(Server):
# Run the agent state forward
return self._step(actor=actor, agent_id=agent_id, input_messages=message)
+ # TODO: Deprecate this
def send_messages(
self,
actor: User,
diff --git a/letta/system.py b/letta/system.py
index 05b83c08..06acb8f9 100644
--- a/letta/system.py
+++ b/letta/system.py
@@ -87,7 +87,7 @@ def get_initial_boot_messages(version="startup"):
return messages
-def get_heartbeat(reason="Automated timer", include_location=False, location_name="San Francisco, CA, USA"):
+def get_heartbeat(reason: str = "Automated timer", include_location: bool = False, location_name: str = "San Francisco, CA, USA"):
# Package the message with time and location
formatted_time = get_local_time()
packaged_message = {
diff --git a/tests/configs/llm_model_configs/openai-gpt-4o.json b/tests/configs/llm_model_configs/openai-gpt-4o.json
index 8e2cd44a..85c6b3ac 100644
--- a/tests/configs/llm_model_configs/openai-gpt-4o.json
+++ b/tests/configs/llm_model_configs/openai-gpt-4o.json
@@ -1,7 +1,7 @@
{
- "context_window": 8192,
- "model": "gpt-4o",
- "model_endpoint_type": "openai",
- "model_endpoint": "https://api.openai.com/v1",
- "model_wrapper": null
+ "context_window": 32000,
+ "model": "gpt-4o",
+ "model_endpoint_type": "openai",
+ "model_endpoint": "https://api.openai.com/v1",
+ "model_wrapper": null
}
diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py
index 6c413dcf..7698e370 100644
--- a/tests/integration_test_agent_tool_graph.py
+++ b/tests/integration_test_agent_tool_graph.py
@@ -1,15 +1,15 @@
-import time
+import asyncio
import uuid
import pytest
+from letta.agents.letta_agent import LettaAgent
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
-from letta.schemas.letta_response import LettaResponse
-from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
-from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
+from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
from letta.server.server import SyncServer
+from letta.services.telemetry_manager import NoopTelemetryManager
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
@@ -25,6 +25,13 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
+@pytest.fixture(scope="module")
+def event_loop():
+ loop = asyncio.new_event_loop()
+ yield loop
+ loop.close()
+
+
@pytest.fixture()
def server():
config = LettaConfig.load()
@@ -181,13 +188,83 @@ def auto_error_tool(server):
yield tool
+@pytest.fixture(scope="function")
+def save_data_tool(server):
+ def save_data():
+ """
+ Saves important data before exiting.
+
+ Returns:
+ str: Confirmation that data was saved.
+ """
+ return "Data saved successfully"
+
+ actor = server.user_manager.get_user_or_default()
+ tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=save_data), actor=actor)
+ yield tool
+
+
+@pytest.fixture(scope="function")
+def cleanup_temp_files_tool(server):
+ def cleanup_temp_files():
+ """
+ Cleans up temporary files before exiting.
+
+ Returns:
+ str: Confirmation that cleanup was completed.
+ """
+ return "Temporary files cleaned up"
+
+ actor = server.user_manager.get_user_or_default()
+ tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=cleanup_temp_files), actor=actor)
+ yield tool
+
+
+@pytest.fixture(scope="function")
+def validate_work_tool(server):
+ def validate_work():
+ """
+ Validates that work is complete before exiting.
+
+ Returns:
+ str: Validation result.
+ """
+ return "Work validation passed"
+
+ actor = server.user_manager.get_user_or_default()
+ tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=validate_work), actor=actor)
+ yield tool
+
+
@pytest.fixture
def default_user(server):
yield server.user_manager.get_user_or_default()
+async def run_agent_step(server, agent_id, input_messages, actor):
+ """Helper function to run agent step using LettaAgent directly instead of server.send_messages."""
+ agent_loop = LettaAgent(
+ agent_id=agent_id,
+ message_manager=server.message_manager,
+ agent_manager=server.agent_manager,
+ block_manager=server.block_manager,
+ job_manager=server.job_manager,
+ passage_manager=server.passage_manager,
+ actor=actor,
+ step_manager=server.step_manager,
+ telemetry_manager=NoopTelemetryManager(),
+ )
+
+ return await agent_loop.step(
+ input_messages,
+ max_steps=50,
+ use_assistant_message=False,
+ )
+
+
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
-def test_single_path_agent_tool_call_graph(
+@pytest.mark.asyncio
+async def test_single_path_agent_tool_call_graph(
server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user
):
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
@@ -207,18 +284,11 @@ def test_single_path_agent_tool_call_graph(
# Make agent state
agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
- usage_stats = server.send_messages(
- actor=default_user,
+ response = await run_agent_step(
+ server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
- )
- messages = [message for step_messages in usage_stats.steps_messages for message in step_messages]
- letta_messages = []
- for m in messages:
- letta_messages += m.to_letta_messages()
-
- response = LettaResponse(
- messages=letta_messages, stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), usage=usage_stats
+ actor=default_user,
)
# Make checks
@@ -299,7 +369,8 @@ def test_check_tool_rules_with_different_models_parametrized(
@pytest.mark.timeout(180)
-def test_claude_initial_tool_rule_enforced(
+@pytest.mark.asyncio
+async def test_claude_initial_tool_rule_enforced(
server,
disable_e2b_api_key,
first_secret_tool,
@@ -325,20 +396,11 @@ def test_claude_initial_tool_rule_enforced(
tool_rules=tool_rules,
)
- usage_stats = server.send_messages(
- actor=default_user,
+ response = await run_agent_step(
+ server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the second secret word?")],
- )
- messages = [m for step in usage_stats.steps_messages for m in step]
- letta_messages = []
- for m in messages:
- letta_messages += m.to_letta_messages()
-
- response = LettaResponse(
- messages=letta_messages,
- stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
- usage=usage_stats,
+ actor=default_user,
)
assert_sanity_checks(response)
@@ -359,7 +421,7 @@ def test_claude_initial_tool_rule_enforced(
# Exponential backoff
if i < 2:
backoff_time = 10 * (2**i)
- time.sleep(backoff_time)
+ await asyncio.sleep(backoff_time)
@pytest.mark.timeout(60)
@@ -370,7 +432,8 @@ def test_claude_initial_tool_rule_enforced(
"tests/configs/llm_model_configs/openai-gpt-4o.json",
],
)
-def test_agent_no_structured_output_with_one_child_tool_parametrized(
+@pytest.mark.asyncio
+async def test_agent_no_structured_output_with_one_child_tool_parametrized(
server,
disable_e2b_api_key,
default_user,
@@ -404,20 +467,11 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
tool_rules=tool_rules,
)
- usage_stats = server.send_messages(
- actor=default_user,
+ response = await run_agent_step(
+ server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="hi. run archival memory search")],
- )
- messages = [m for step in usage_stats.steps_messages for m in step]
- letta_messages = []
- for m in messages:
- letta_messages += m.to_letta_messages()
-
- response = LettaResponse(
- messages=letta_messages,
- stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
- usage=usage_stats,
+ actor=default_user,
)
# Run assertions
@@ -448,7 +502,8 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
@pytest.mark.timeout(30)
@pytest.mark.parametrize("include_base_tools", [False, True])
-def test_init_tool_rule_always_fails(
+@pytest.mark.asyncio
+async def test_init_tool_rule_always_fails(
server,
disable_e2b_api_key,
auto_error_tool,
@@ -469,17 +524,11 @@ def test_init_tool_rule_always_fails(
include_base_tools=include_base_tools,
)
- usage_stats = server.send_messages(
- actor=default_user,
+ response = await run_agent_step(
+ server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="blah blah blah")],
- )
- messages = [m for step in usage_stats.steps_messages for m in step]
- letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
- response = LettaResponse(
- messages=letta_messages,
- stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
- usage=usage_stats,
+ actor=default_user,
)
assert_invoked_function_call(response.messages, auto_error_tool.name)
@@ -487,7 +536,8 @@ def test_init_tool_rule_always_fails(
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
-def test_continue_tool_rule(server, default_user):
+@pytest.mark.asyncio
+async def test_continue_tool_rule(server, default_user):
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_uuid = str(uuid.uuid4())
@@ -512,17 +562,11 @@ def test_continue_tool_rule(server, default_user):
include_base_tool_rules=False,
)
- usage_stats = server.send_messages(
- actor=default_user,
+ response = await run_agent_step(
+ server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")],
- )
- messages = [m for step in usage_stats.steps_messages for m in step]
- letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
- response = LettaResponse(
- messages=letta_messages,
- stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
- usage=usage_stats,
+ actor=default_user,
)
assert_invoked_function_call(response.messages, "send_message")
@@ -775,3 +819,180 @@ def test_continue_tool_rule(server, default_user):
# assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin"
#
# cleanup(client, agent_uuid=agent_state.id)
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.asyncio
+async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
+ """Test that agent is forced to call a single required-before-exit tool before ending."""
+ agent_name = "required_exit_single_tool_agent"
+ config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
+
+ # Set up tools and rules
+ tools = [save_data_tool]
+ tool_rules = [
+ InitToolRule(tool_name="send_message"),
+ RequiredBeforeExitToolRule(tool_name="save_data"),
+ TerminalToolRule(tool_name="send_message"),
+ ]
+
+ # Create agent
+ agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
+
+ # Send message that would normally cause exit
+ response = await run_agent_step(
+ server=server,
+ agent_id=agent_state.id,
+ input_messages=[MessageCreate(role="user", content="Please finish your work and send me a message.")],
+ actor=default_user,
+ )
+
+ # Assertions
+ assert_sanity_checks(response)
+ assert_invoked_function_call(response.messages, "save_data")
+ assert_invoked_function_call(response.messages, "send_message")
+
+ # The key test is that both tools were called - the agent was forced to call save_data
+ # even when it tried to exit early with send_message
+ tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
+ save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
+ send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]
+
+ assert len(save_data_calls) >= 1, "save_data should be called at least once"
+ assert len(send_message_calls) >= 1, "send_message should be called at least once"
+
+ print(f"✓ Agent '{agent_name}' successfully called required tool before exit")
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.asyncio
+async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
+ """Test that agent calls all required-before-exit tools before ending."""
+ agent_name = "required_exit_multi_tool_agent"
+ config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
+
+ # Set up tools and rules
+ tools = [save_data_tool, cleanup_temp_files_tool]
+ tool_rules = [
+ InitToolRule(tool_name="send_message"),
+ RequiredBeforeExitToolRule(tool_name="save_data"),
+ RequiredBeforeExitToolRule(tool_name="cleanup_temp_files"),
+ TerminalToolRule(tool_name="send_message"),
+ ]
+
+ # Create agent
+ agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
+
+ # Send message that would normally cause exit
+ response = await run_agent_step(
+ server=server,
+ agent_id=agent_state.id,
+ input_messages=[MessageCreate(role="user", content="Complete all necessary tasks and then send me a message.")],
+ actor=default_user,
+ )
+
+ # Assertions
+ assert_sanity_checks(response)
+ assert_invoked_function_call(response.messages, "save_data")
+ assert_invoked_function_call(response.messages, "cleanup_temp_files")
+ assert_invoked_function_call(response.messages, "send_message")
+
+ # Verify that all required tools were eventually called
+ tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
+ save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
+ cleanup_calls = [tc for tc in tool_calls if tc.tool_call.name == "cleanup_temp_files"]
+ send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]
+
+ assert len(save_data_calls) >= 1, "save_data should be called at least once"
+ assert len(cleanup_calls) >= 1, "cleanup_temp_files should be called at least once"
+ assert len(send_message_calls) >= 1, "send_message should be called at least once"
+
+ print(f"✓ Agent '{agent_name}' successfully called all required tools before exit")
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.asyncio
+async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
+ """Test required-before-exit rules work alongside other tool rules."""
+ agent_name = "required_exit_with_rules_agent"
+ config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
+
+ # Set up tools and rules - combine with child tool rules
+ tools = [first_secret_tool, save_data_tool]
+ tool_rules = [
+ InitToolRule(tool_name="first_secret_word"),
+ ChildToolRule(tool_name="first_secret_word", children=["send_message"]),
+ RequiredBeforeExitToolRule(tool_name="save_data"),
+ TerminalToolRule(tool_name="send_message"),
+ ]
+
+ # Create agent
+ agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
+
+ # Send message that would trigger tool flow
+ response = await run_agent_step(
+ server=server,
+ agent_id=agent_state.id,
+ input_messages=[MessageCreate(role="user", content="Get the first secret word and then finish up.")],
+ actor=default_user,
+ )
+
+ # Assertions
+ assert_sanity_checks(response)
+ assert_invoked_function_call(response.messages, "first_secret_word")
+ assert_invoked_function_call(response.messages, "save_data")
+ assert_invoked_function_call(response.messages, "send_message")
+
+ # Verify that all tools were called (first_secret_word due to InitToolRule, save_data due to RequiredBeforeExitToolRule)
+ tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
+ first_secret_calls = [tc for tc in tool_calls if tc.tool_call.name == "first_secret_word"]
+ save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
+ send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]
+
+ assert len(first_secret_calls) >= 1, "first_secret_word should be called due to InitToolRule"
+ assert len(save_data_calls) >= 1, "save_data should be called due to RequiredBeforeExitToolRule"
+ assert len(send_message_calls) >= 1, "send_message should be called eventually"
+
+ print(f"✓ Agent '{agent_name}' successfully handled mixed tool rules")
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.asyncio
+async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
+ """Test that agent can exit normally when required tools are called during regular operation."""
+ agent_name = "required_exit_normal_flow_agent"
+ config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
+
+ # Set up tools and rules
+ tools = [save_data_tool]
+ tool_rules = [
+ InitToolRule(tool_name="save_data"),
+ RequiredBeforeExitToolRule(tool_name="send_message"),
+ TerminalToolRule(tool_name="send_message"),
+ ]
+
+ # Create agent
+ agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
+
+ # Send message that explicitly mentions calling the required tool
+ response = await run_agent_step(
+ server=server,
+ agent_id=agent_state.id,
+ input_messages=[MessageCreate(role="user", content="Please save data and then send me a message when done.")],
+ actor=default_user,
+ )
+
+ # Assertions
+ assert_sanity_checks(response)
+ assert_invoked_function_call(response.messages, "save_data")
+ assert_invoked_function_call(response.messages, "send_message")
+
+ # Should not have excessive tool calls - agent should exit cleanly after requirements are met
+ tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
+ save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
+ send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]
+
+ assert len(save_data_calls) == 1, "Should call save_data exactly once"
+ assert len(send_message_calls) == 1, "Should call send_message exactly once"
+
+ print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
diff --git a/tests/test_sources.py b/tests/test_sources.py
index a654ab4d..42a289de 100644
--- a/tests/test_sources.py
+++ b/tests/test_sources.py
@@ -61,6 +61,7 @@ def agent_state(client: LettaSDKClient):
grep_tool = client.tools.list(name="grep")[0]
agent_state = client.agents.create(
+ name="test_sources_agent",
memory_blocks=[
CreateBlock(
label="human",
diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py
index 39015113..bdff9e2e 100644
--- a/tests/test_tool_rule_solver.py
+++ b/tests/test_tool_rule_solver.py
@@ -2,7 +2,14 @@ import pytest
from letta.helpers import ToolRulesSolver
from letta.helpers.tool_rule_solver import ToolRuleValidationError
-from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule
+from letta.schemas.tool_rule import (
+ ChildToolRule,
+ ConditionalToolRule,
+ InitToolRule,
+ MaxCountPerStepToolRule,
+ RequiredBeforeExitToolRule,
+ TerminalToolRule,
+)
# Constants for tool names used in the tests
START_TOOL = "start_tool"
@@ -12,6 +19,9 @@ HELPER_TOOL = "helper_tool"
FINAL_TOOL = "final_tool"
END_TOOL = "end_tool"
UNRECOGNIZED_TOOL = "unrecognized_tool"
+REQUIRED_TOOL_1 = "required_tool_1"
+REQUIRED_TOOL_2 = "required_tool_2"
+SAVE_TOOL = "save_tool"
def test_get_allowed_tool_names_with_init_rules():
@@ -175,3 +185,93 @@ def test_max_count_per_step_tool_rule_resets_on_clear():
solver.clear_tool_history()
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should allow 'start_tool' again after clearing history"
+
+
def test_required_before_exit_tool_rule_has_required_tools_been_called():
    """A solver configured with no rules reports every required tool as already called."""
    empty_solver = ToolRulesSolver(tool_rules=[])

    # Vacuously true: there are no required-before-exit tools to wait on.
    assert empty_solver.has_required_tools_been_called() is True, "Should return True when no required tools are defined"
+
+
def test_required_before_exit_tool_rule_single_required_tool():
    """Exercise the full lifecycle of one required-before-exit tool: uncalled, then called."""
    rules_solver = ToolRulesSolver(tool_rules=[RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)])

    # Before any calls, the save tool is outstanding.
    assert rules_solver.has_required_tools_been_called() is False, "Should return False when required tool hasn't been called"
    assert rules_solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should return list with uncalled required tool"

    # Record the required call, then verify the solver considers the requirement satisfied.
    rules_solver.register_tool_call(SAVE_TOOL)

    assert rules_solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
    assert rules_solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"
+
+
def test_required_before_exit_tool_rule_multiple_required_tools():
    """With two required-before-exit tools, both must be registered before exit is allowed."""
    rules_solver = ToolRulesSolver(
        tool_rules=[
            RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_1),
            RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_2),
        ]
    )

    # Neither tool has run yet: both must appear as outstanding (order-insensitive check).
    assert rules_solver.has_required_tools_been_called() is False, "Should return False when no required tools have been called"
    outstanding = rules_solver.get_uncalled_required_tools()
    assert set(outstanding) == {REQUIRED_TOOL_1, REQUIRED_TOOL_2}, "Should return both uncalled required tools"

    # Registering only the first tool leaves the second outstanding.
    rules_solver.register_tool_call(REQUIRED_TOOL_1)

    assert rules_solver.has_required_tools_been_called() is False, "Should return False when only one required tool has been called"
    assert rules_solver.get_uncalled_required_tools() == [REQUIRED_TOOL_2], "Should return remaining uncalled required tool"

    # Registering the second tool completes the requirement set.
    rules_solver.register_tool_call(REQUIRED_TOOL_2)

    assert rules_solver.has_required_tools_been_called() is True, "Should return True when all required tools have been called"
    assert rules_solver.get_uncalled_required_tools() == [], "Should return empty list when all required tools have been called"
+
+
def test_required_before_exit_tool_rule_mixed_with_other_tools():
    """Calls to unrelated tools must not satisfy a required-before-exit rule."""
    rules_solver = ToolRulesSolver(tool_rules=[RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)])

    # Register a couple of unrelated tools first; the requirement stays unmet.
    for unrelated in (START_TOOL, HELPER_TOOL):
        rules_solver.register_tool_call(unrelated)

    assert rules_solver.has_required_tools_been_called() is False, "Should return False even after calling other tools"
    assert rules_solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should still show required tool as uncalled"

    # Only the actual required tool flips the state.
    rules_solver.register_tool_call(SAVE_TOOL)

    assert rules_solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
    assert rules_solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"
+
+
def test_required_before_exit_tool_rule_is_terminal():
    """A required-before-exit tool counts as terminal, alongside explicit terminal rules."""
    rules_solver = ToolRulesSolver(
        tool_rules=[
            RequiredBeforeExitToolRule(tool_name=SAVE_TOOL),
            TerminalToolRule(tool_name=END_TOOL),
        ]
    )

    # Both rule kinds should make their tool terminal; anything else stays non-terminal.
    assert rules_solver.is_terminal_tool(SAVE_TOOL) is True, "Required-before-exit tool should be considered terminal"
    assert rules_solver.is_terminal_tool(END_TOOL) is True, "Regular terminal tool should still be considered terminal"
    assert rules_solver.is_terminal_tool(START_TOOL) is False, "Non-terminal tool should not be considered terminal"
+
+
def test_required_before_exit_tool_rule_clear_history():
    """Clearing the tool-call history must re-arm the required-before-exit rule."""
    rules_solver = ToolRulesSolver(tool_rules=[RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)])

    # Satisfy the requirement once.
    rules_solver.register_tool_call(SAVE_TOOL)
    assert rules_solver.has_required_tools_been_called() is True, "Should return True after required tool is called"

    # Wiping the history should make the tool outstanding again.
    rules_solver.clear_tool_history()

    assert rules_solver.has_required_tools_been_called() is False, "Should return False after clearing history"
    assert rules_solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"