diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index 7d5149c3..6d40704f 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -33,6 +33,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
+from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@@ -170,12 +171,14 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
- request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
- current_in_context_messages,
- new_in_context_messages,
- agent_state,
- llm_client,
- tool_rules_solver,
+ request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
+ await self._build_and_request_from_llm(
+ current_in_context_messages,
+ new_in_context_messages,
+ agent_state,
+ llm_client,
+ tool_rules_solver,
+ )
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -220,6 +223,7 @@ class LettaAgent(BaseAgent):
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
+ valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
@@ -323,8 +327,10 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
- request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
- current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
+ request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
+ await self._build_and_request_from_llm(
+ current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
+ )
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -363,6 +369,7 @@ class LettaAgent(BaseAgent):
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
+ valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
@@ -460,15 +467,17 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
- request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming(
- first_chunk,
- agent_step_span,
- request_start_timestamp_ns,
- current_in_context_messages,
- new_in_context_messages,
- agent_state,
- llm_client,
- tool_rules_solver,
+ request_data, stream, current_in_context_messages, new_in_context_messages, valid_tool_names = (
+ await self._build_and_request_from_llm_streaming(
+ first_chunk,
+ agent_step_span,
+ request_start_timestamp_ns,
+ current_in_context_messages,
+ new_in_context_messages,
+ agent_state,
+ llm_client,
+ tool_rules_solver,
+ )
)
log_event("agent.stream.llm_response.received") # [3^]
@@ -523,6 +532,7 @@ class LettaAgent(BaseAgent):
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
+ valid_tool_names,
agent_state,
tool_rules_solver,
UsageStatistics(
@@ -611,12 +621,12 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
- ) -> Tuple[Dict, Dict, List[Message], List[Message]]:
+ ) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]]:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
# Create LLM request data
- request_data = await self._create_llm_request_data_async(
+ request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
@@ -631,12 +641,7 @@ class LettaAgent(BaseAgent):
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Attempt LLM request
- return (
- request_data,
- response,
- current_in_context_messages,
- new_in_context_messages,
- )
+ return (request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names)
except Exception as e:
if attempt == self.max_summarization_retries:
@@ -664,12 +669,12 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
- ) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message]]:
+ ) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str]]:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
# Create LLM request data
- request_data = await self._create_llm_request_data_async(
+ request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
@@ -688,6 +693,7 @@ class LettaAgent(BaseAgent):
await llm_client.stream_async(request_data, agent_state.llm_config),
current_in_context_messages,
new_in_context_messages,
+ valid_tool_names,
)
except Exception as e:
@@ -770,7 +776,7 @@ class LettaAgent(BaseAgent):
in_context_messages: List[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
- ) -> dict:
+ ) -> Tuple[dict, List[str]]:
self.num_messages, self.num_archival_memories = await asyncio.gather(
(
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
@@ -827,17 +833,21 @@ class LettaAgent(BaseAgent):
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
)
- return llm_client.build_request_data(
- in_context_messages,
- agent_state.llm_config,
- allowed_tools,
- force_tool_call,
+ return (
+ llm_client.build_request_data(
+ in_context_messages,
+ agent_state.llm_config,
+ allowed_tools,
+ force_tool_call,
+ ),
+ valid_tool_names,
)
@trace_method
async def _handle_ai_response(
self,
tool_call: ToolCall,
+ valid_tool_names: List[str],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
usage: UsageStatistics,
@@ -853,8 +863,10 @@ class LettaAgent(BaseAgent):
This might yield additional SSE tokens if we do stalling.
At the end, set self._continue_execution accordingly.
"""
+        # Extract the called tool's name (validated against valid_tool_names before execution below):
tool_call_name = tool_call.function.name
tool_call_args_str = tool_call.function.arguments
+
# Temp hack to gracefully handle parallel tool calling attempt, only take first one
if "}{" in tool_call_args_str:
tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
@@ -893,14 +905,21 @@ class LettaAgent(BaseAgent):
tool_call_id=tool_call_id,
request_heartbeat=request_heartbeat,
)
-
- tool_execution_result = await self._execute_tool(
- tool_name=tool_call_name,
- tool_args=tool_args,
- agent_state=agent_state,
- agent_step_span=agent_step_span,
- step_id=step_id,
- )
+ if tool_call_name not in valid_tool_names:
+ base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}."
+ violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name)
+ if violated_rule_messages:
+ bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages)
+ base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}"
+ tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message)
+ else:
+ tool_execution_result = await self._execute_tool(
+ tool_name=tool_call_name,
+ tool_args=tool_args,
+ agent_state=agent_state,
+ agent_step_span=agent_step_span,
+ step_id=step_id,
+ )
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
)
diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py
index caecd271..acd35c8a 100644
--- a/letta/helpers/tool_rule_solver.py
+++ b/letta/helpers/tool_rule_solver.py
@@ -172,6 +172,40 @@ class ToolRulesSolver(BaseModel):
)
return None
+ def guess_rule_violation(self, tool_name: str) -> List[str]:
+ """
+ Check if the given tool name or the previous tool in history matches any tool rule,
+ and return rendered prompt templates for matching rules.
+
+ Args:
+ tool_name: The name of the tool to check for rule violations
+
+ Returns:
+ List of rendered prompt templates from matching tool rules
+ """
+ violated_rules = []
+
+ # Get the previous tool from history if it exists
+ previous_tool = self.tool_call_history[-1] if self.tool_call_history else None
+
+ # Check all tool rules for matches
+ all_rules = (
+ self.init_tool_rules
+ + self.continue_tool_rules
+ + self.child_based_tool_rules
+ + self.parent_tool_rules
+ + self.terminal_tool_rules
+ )
+
+ for rule in all_rules:
+ # Check if the current tool name or previous tool matches this rule's tool_name
+ if rule.tool_name == tool_name or (previous_tool and rule.tool_name == previous_tool):
+ rendered_prompt = rule.render_prompt()
+ if rendered_prompt:
+ violated_rules.append(rendered_prompt)
+
+ return violated_rules
+
@staticmethod
def validate_conditional_tool(rule: ConditionalToolRule):
"""
diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py
index 7efa07c0..94dc4978 100644
--- a/letta/schemas/tool_rule.py
+++ b/letta/schemas/tool_rule.py
@@ -25,19 +25,24 @@ class BaseToolRule(LettaBase):
def render_prompt(self) -> Optional[str]:
"""Render the prompt template with this rule's attributes."""
- if not self.prompt_template:
+ template_to_use = self.prompt_template or self._get_default_template()
+ if not template_to_use:
return None
try:
- template = Template(self.prompt_template)
+ template = Template(template_to_use)
return template.render(**self.model_dump())
except Exception as e:
logger.warning(
f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). "
- f"Template: '{self.prompt_template}'. Error: {e}"
+ f"Template: '{template_to_use}'. Error: {e}"
)
return None
+ def _get_default_template(self) -> Optional[str]:
+ """Get the default template for this rule type. Override in subclasses."""
+ return None
+
class ChildToolRule(BaseToolRule):
"""
@@ -46,11 +51,18 @@ class ChildToolRule(BaseToolRule):
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
children: List[str] = Field(..., description="The children tools that can be invoked.")
+ prompt_template: Optional[str] = Field(
+ default="After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}",
+ description="Optional Jinja2 template for generating agent prompt about this tool rule.",
+ )
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
last_tool = tool_call_history[-1] if tool_call_history else None
return set(self.children) if last_tool == self.tool_name else available_tools
+ def _get_default_template(self) -> Optional[str]:
+ return "After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}"
+
class ParentToolRule(BaseToolRule):
"""
@@ -59,11 +71,18 @@ class ParentToolRule(BaseToolRule):
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
children: List[str] = Field(..., description="The children tools that can be invoked.")
+ prompt_template: Optional[str] = Field(
+ default="{{ children | join(', ') }} can only be used after {{ tool_name }}",
+ description="Optional Jinja2 template for generating agent prompt about this tool rule.",
+ )
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
last_tool = tool_call_history[-1] if tool_call_history else None
return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
+ def _get_default_template(self) -> Optional[str]:
+ return "{{ children | join(', ') }} can only be used after {{ tool_name }}"
+
class ConditionalToolRule(BaseToolRule):
"""
@@ -74,6 +93,10 @@ class ConditionalToolRule(BaseToolRule):
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
+ prompt_template: Optional[str] = Field(
+ default="{{ tool_name }} will determine which tool to use next based on its output",
+ description="Optional Jinja2 template for generating agent prompt about this tool rule.",
+ )
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
"""Determine valid tools based on function output mapping."""
@@ -119,6 +142,9 @@ class ConditionalToolRule(BaseToolRule):
else: # Assume string
return str(function_output) == str(key)
+ def _get_default_template(self) -> Optional[str]:
+ return "{{ tool_name }} will determine which tool to use next based on its output"
+
class InitToolRule(BaseToolRule):
"""
@@ -134,6 +160,13 @@ class TerminalToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop
+ prompt_template: Optional[str] = Field(
+ default="{{ tool_name }} ends the conversation when called",
+ description="Optional Jinja2 template for generating agent prompt about this tool rule.",
+ )
+
+ def _get_default_template(self) -> Optional[str]:
+ return "{{ tool_name }} ends the conversation when called"
class ContinueToolRule(BaseToolRule):
@@ -142,6 +175,10 @@ class ContinueToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
+ prompt_template: Optional[str] = Field(
+ default="{{ tool_name }} requires continuing the conversation when called",
+ description="Optional Jinja2 template for generating agent prompt about this tool rule.",
+ )
class MaxCountPerStepToolRule(BaseToolRule):
@@ -166,6 +203,9 @@ class MaxCountPerStepToolRule(BaseToolRule):
return available_tools
+ def _get_default_template(self) -> Optional[str]:
+ return "{{ tool_name }}: max {{ max_count_limit }} use(s) per turn"
+
ToolRule = Annotated[
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py
index 943e779b..916f1863 100644
--- a/tests/test_sdk_client.py
+++ b/tests/test_sdk_client.py
@@ -520,51 +520,6 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState):
# assert len(responses) == len(messages), "Not all messages were processed"
-def test_send_message_async(client: LettaSDKClient, agent: AgentState):
- """
- Test that we can send a message asynchronously and retrieve the messages, along with usage statistics
- """
- test_message = "This is a test message, respond to the user with a sentence."
- run = client.agents.messages.create_async(
- agent_id=agent.id,
- messages=[
- MessageCreate(
- role="user",
- content=test_message,
- ),
- ],
- use_assistant_message=False,
- )
- assert run.id is not None
- assert run.status == "created"
-
- # Wait for the job to complete, cancel it if takes over 10 seconds
- start_time = time.time()
- while run.status == "created":
- time.sleep(1)
- run = client.runs.retrieve(run_id=run.id)
- print(f"Run status: {run.status}")
- if time.time() - start_time > 10:
- pytest.fail("Run took too long to complete")
-
- print(f"Run completed in {time.time() - start_time} seconds, run={run}")
- assert run.status == "completed"
-
- # Get messages for the job
- messages = client.runs.messages.list(run_id=run.id)
- assert len(messages) >= 2 # At least assistant response
-
- # Check filters
- assistant_messages = client.runs.messages.list(run_id=run.id, role="assistant")
- assert len(assistant_messages) > 0
- tool_messages = client.runs.messages.list(run_id=run.id, role="tool")
- assert len(tool_messages) > 0
-
- # specific_tool_messages = [message for message in client.runs.list_run_messages(run_id=run.id) if isinstance(message, ToolCallMessage)]
- # assert specific_tool_messages[0].tool_call.name == "send_message"
- # assert len(specific_tool_messages) > 0
-
-
def test_agent_creation(client: LettaSDKClient):
"""Test that block IDs are properly attached when creating an agent."""
sleeptime_agent_system = """
diff --git a/tests/test_server.py b/tests/test_server.py
index 4f6fdc25..6a13f672 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -901,18 +901,6 @@ async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, u
assert not result.stderr
-def test_composio_client_simple(server):
- apps = server.get_composio_apps()
- # Assert there's some amount of apps returned
- assert len(apps) > 0
-
- app = apps[0]
- actions = server.get_composio_actions_from_app_name(composio_app_name=app.name)
-
- # Assert there's some amount of actions
- assert len(actions) > 0
-
-
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = user