feat: Add errors when tool call violates tool rules (#2766)
This commit is contained in:
@@ -33,6 +33,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
||||
@@ -170,12 +171,14 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
)
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
@@ -220,6 +223,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
@@ -323,8 +327,10 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
|
||||
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
|
||||
)
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
@@ -363,6 +369,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
@@ -460,15 +467,17 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming(
|
||||
first_chunk,
|
||||
agent_step_span,
|
||||
request_start_timestamp_ns,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
request_data, stream, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm_streaming(
|
||||
first_chunk,
|
||||
agent_step_span,
|
||||
request_start_timestamp_ns,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
)
|
||||
)
|
||||
log_event("agent.stream.llm_response.received") # [3^]
|
||||
|
||||
@@ -523,6 +532,7 @@ class LettaAgent(BaseAgent):
|
||||
reasoning_content = interface.get_reasoning_content()
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call,
|
||||
valid_tool_names,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
UsageStatistics(
|
||||
@@ -611,12 +621,12 @@ class LettaAgent(BaseAgent):
|
||||
agent_state: AgentState,
|
||||
llm_client: LLMClientBase,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> Tuple[Dict, Dict, List[Message], List[Message]]:
|
||||
) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]]:
|
||||
for attempt in range(self.max_summarization_retries + 1):
|
||||
try:
|
||||
log_event("agent.stream_no_tokens.messages.refreshed")
|
||||
# Create LLM request data
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
@@ -631,12 +641,7 @@ class LettaAgent(BaseAgent):
|
||||
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
|
||||
)
|
||||
# Attempt LLM request
|
||||
return (
|
||||
request_data,
|
||||
response,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
)
|
||||
return (request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names)
|
||||
|
||||
except Exception as e:
|
||||
if attempt == self.max_summarization_retries:
|
||||
@@ -664,12 +669,12 @@ class LettaAgent(BaseAgent):
|
||||
agent_state: AgentState,
|
||||
llm_client: LLMClientBase,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message]]:
|
||||
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str]]:
|
||||
for attempt in range(self.max_summarization_retries + 1):
|
||||
try:
|
||||
log_event("agent.stream_no_tokens.messages.refreshed")
|
||||
# Create LLM request data
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
@@ -688,6 +693,7 @@ class LettaAgent(BaseAgent):
|
||||
await llm_client.stream_async(request_data, agent_state.llm_config),
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
valid_tool_names,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -770,7 +776,7 @@ class LettaAgent(BaseAgent):
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> dict:
|
||||
) -> Tuple[dict, List[str]]:
|
||||
self.num_messages, self.num_archival_memories = await asyncio.gather(
|
||||
(
|
||||
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
@@ -827,17 +833,21 @@ class LettaAgent(BaseAgent):
|
||||
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
|
||||
)
|
||||
|
||||
return llm_client.build_request_data(
|
||||
in_context_messages,
|
||||
agent_state.llm_config,
|
||||
allowed_tools,
|
||||
force_tool_call,
|
||||
return (
|
||||
llm_client.build_request_data(
|
||||
in_context_messages,
|
||||
agent_state.llm_config,
|
||||
allowed_tools,
|
||||
force_tool_call,
|
||||
),
|
||||
valid_tool_names,
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def _handle_ai_response(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
valid_tool_names: List[str],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
usage: UsageStatistics,
|
||||
@@ -853,8 +863,10 @@ class LettaAgent(BaseAgent):
|
||||
This might yield additional SSE tokens if we do stalling.
|
||||
At the end, set self._continue_execution accordingly.
|
||||
"""
|
||||
# Check if the called tool is allowed by tool name:
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_args_str = tool_call.function.arguments
|
||||
|
||||
# Temp hack to gracefully handle parallel tool calling attempt, only take first one
|
||||
if "}{" in tool_call_args_str:
|
||||
tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
|
||||
@@ -893,14 +905,21 @@ class LettaAgent(BaseAgent):
|
||||
tool_call_id=tool_call_id,
|
||||
request_heartbeat=request_heartbeat,
|
||||
)
|
||||
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
if tool_call_name not in valid_tool_names:
|
||||
base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}."
|
||||
violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name)
|
||||
if violated_rule_messages:
|
||||
bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages)
|
||||
base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}"
|
||||
tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message)
|
||||
else:
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
log_telemetry(
|
||||
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
|
||||
)
|
||||
|
||||
@@ -172,6 +172,40 @@ class ToolRulesSolver(BaseModel):
|
||||
)
|
||||
return None
|
||||
|
||||
def guess_rule_violation(self, tool_name: str) -> List[str]:
    """
    Collect prompt hints for tool rules that may explain a rejected tool call.

    This is a heuristic: a rule is reported when its ``tool_name`` equals either
    the requested tool or the most recent entry in ``tool_call_history``. No
    check is made that the rule was actually violated.

    Args:
        tool_name: The name of the tool to check for rule violations

    Returns:
        Rendered prompt templates of the matching rules, in rule-list order
        (init, continue, child-based, parent, terminal). Rules whose
        ``render_prompt()`` returns a falsy value are skipped.
    """
    # Get the previous tool from history if it exists
    previous_tool = self.tool_call_history[-1] if self.tool_call_history else None

    # Every rule category is considered; order here determines output order.
    all_rules = (
        self.init_tool_rules
        + self.continue_tool_rules
        + self.child_based_tool_rules
        + self.parent_tool_rules
        + self.terminal_tool_rules
    )

    violated_rules = []
    for rule in all_rules:
        matches_current = rule.tool_name == tool_name
        # A falsy previous tool (no history) never matches.
        matches_previous = bool(previous_tool) and rule.tool_name == previous_tool
        if not (matches_current or matches_previous):
            continue

        rendered_prompt = rule.render_prompt()
        if rendered_prompt:
            violated_rules.append(rendered_prompt)

    return violated_rules
|
||||
|
||||
@staticmethod
|
||||
def validate_conditional_tool(rule: ConditionalToolRule):
|
||||
"""
|
||||
|
||||
@@ -25,19 +25,24 @@ class BaseToolRule(LettaBase):
|
||||
|
||||
def render_prompt(self) -> Optional[str]:
    """Render the prompt template with this rule's attributes.

    Uses ``prompt_template`` when it is set (truthy); otherwise falls back to
    the template returned by ``_get_default_template()``.

    Returns:
        The rendered string, or None when no template is available or when
        rendering raises (in which case a warning is logged).
    """
    template_to_use = self.prompt_template or self._get_default_template()
    if not template_to_use:
        return None

    try:
        template = Template(template_to_use)
        # The rule's own fields are the template variables.
        return template.render(**self.model_dump())
    except Exception as e:
        # Best-effort: a broken template must not break the caller, so log and give up.
        logger.warning(
            f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). "
            f"Template: '{template_to_use}'. Error: {e}"
        )
        return None
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
"""Get the default template for this rule type. Override in subclasses."""
|
||||
return None
|
||||
|
||||
|
||||
class ChildToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -46,11 +51,18 @@ class ChildToolRule(BaseToolRule):
|
||||
|
||||
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
|
||||
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_constraint>After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}</tool_constraint>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
    """Return the tools that may be called next under this child rule.

    Args:
        tool_call_history: Names of tools called so far, oldest first.
        available_tools: The candidate tool names.
        last_function_response: Unused by this rule.

    Returns:
        ``children`` as a set when the most recent call was this rule's
        ``tool_name``; otherwise ``available_tools`` unchanged.
    """
    last_tool = tool_call_history[-1] if tool_call_history else None
    if last_tool == self.tool_name:
        return set(self.children)
    return available_tools
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
return "<tool_constraint>After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}</tool_constraint>"
|
||||
|
||||
|
||||
class ParentToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -59,11 +71,18 @@ class ParentToolRule(BaseToolRule):
|
||||
|
||||
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
|
||||
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_constraint>{{ children | join(', ') }} can only be used after {{ tool_name }}</tool_constraint>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
    """Return the tools that may be called next under this parent rule.

    Args:
        tool_call_history: Names of tools called so far, oldest first.
        available_tools: The candidate tool names.
        last_function_response: Unused by this rule.

    Returns:
        ``children`` as a set when the most recent call was this rule's
        ``tool_name``; otherwise ``available_tools`` with the children removed.
    """
    last_tool = tool_call_history[-1] if tool_call_history else None
    child_tools = set(self.children)
    if last_tool == self.tool_name:
        return child_tools
    # The children are locked until the parent tool has just been called.
    return available_tools - child_tools
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
return "<tool_constraint>{{ children | join(', ') }} can only be used after {{ tool_name }}</tool_constraint>"
|
||||
|
||||
|
||||
class ConditionalToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -74,6 +93,10 @@ class ConditionalToolRule(BaseToolRule):
|
||||
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
|
||||
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
|
||||
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_constraint>{{ tool_name }} will determine which tool to use next based on its output</tool_constraint>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||
"""Determine valid tools based on function output mapping."""
|
||||
@@ -119,6 +142,9 @@ class ConditionalToolRule(BaseToolRule):
|
||||
else: # Assume string
|
||||
return str(function_output) == str(key)
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
return "<tool_constraint>{{ tool_name }} will determine which tool to use next based on its output</tool_constraint>"
|
||||
|
||||
|
||||
class InitToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -134,6 +160,13 @@ class TerminalToolRule(BaseToolRule):
|
||||
"""
|
||||
|
||||
type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_constraint>{{ tool_name }} ends the conversation when called</tool_constraint>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
return "<tool_constraint>{{ tool_name }} ends the conversation when called</tool_constraint>"
|
||||
|
||||
|
||||
class ContinueToolRule(BaseToolRule):
|
||||
@@ -142,6 +175,10 @@ class ContinueToolRule(BaseToolRule):
|
||||
"""
|
||||
|
||||
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_constraint>{{ tool_name }} requires continuing the conversation when called</tool_constraint>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
|
||||
class MaxCountPerStepToolRule(BaseToolRule):
|
||||
@@ -166,6 +203,9 @@ class MaxCountPerStepToolRule(BaseToolRule):
|
||||
|
||||
return available_tools
|
||||
|
||||
def _get_default_template(self) -> Optional[str]:
|
||||
return "<tool_constraint>{{ tool_name }}: max {{ max_count_limit }} use(s) per turn</tool_constraint>"
|
||||
|
||||
|
||||
ToolRule = Annotated[
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
|
||||
|
||||
@@ -520,51 +520,6 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState):
|
||||
# assert len(responses) == len(messages), "Not all messages were processed"
|
||||
|
||||
|
||||
def test_send_message_async(client: LettaSDKClient, agent: AgentState):
    """
    Test that we can send a message asynchronously, wait for the run to complete,
    and retrieve the run's messages (optionally filtered by role).
    """
    test_message = "This is a test message, respond to the user with a sentence."
    run = client.agents.messages.create_async(
        agent_id=agent.id,
        messages=[
            MessageCreate(
                role="user",
                content=test_message,
            ),
        ],
        use_assistant_message=False,
    )
    assert run.id is not None
    assert run.status == "created"

    # Poll until the run leaves the "created" state; fail the test if that takes over 10 seconds.
    # NOTE(review): the loop exits on any status other than "created", so an intermediate
    # in-progress status (if the API has one) would fail the assertion below - confirm.
    start_time = time.time()
    while run.status == "created":
        time.sleep(1)
        run = client.runs.retrieve(run_id=run.id)
        print(f"Run status: {run.status}")
        if time.time() - start_time > 10:
            pytest.fail("Run took too long to complete")

    print(f"Run completed in {time.time() - start_time} seconds, run={run}")
    assert run.status == "completed"

    # Get messages for the run
    messages = client.runs.messages.list(run_id=run.id)
    assert len(messages) >= 2  # Expect at least two messages recorded for the run

    # Check role filters
    assistant_messages = client.runs.messages.list(run_id=run.id, role="assistant")
    assert len(assistant_messages) > 0
    tool_messages = client.runs.messages.list(run_id=run.id, role="tool")
    assert len(tool_messages) > 0

    # specific_tool_messages = [message for message in client.runs.list_run_messages(run_id=run.id) if isinstance(message, ToolCallMessage)]
    # assert specific_tool_messages[0].tool_call.name == "send_message"
    # assert len(specific_tool_messages) > 0
|
||||
|
||||
|
||||
def test_agent_creation(client: LettaSDKClient):
|
||||
"""Test that block IDs are properly attached when creating an agent."""
|
||||
sleeptime_agent_system = """
|
||||
|
||||
@@ -901,18 +901,6 @@ async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, u
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_composio_client_simple(server):
    """Check that the server lists at least one Composio app and that the first app exposes at least one action."""
    apps = server.get_composio_apps()
    # Assert there's some amount of apps returned
    assert len(apps) > 0

    first_app = apps[0]
    actions = server.get_composio_actions_from_app_name(composio_app_name=first_app.name)

    # Assert there's some amount of actions
    assert len(actions) > 0
|
||||
|
||||
|
||||
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = user
|
||||
|
||||
Reference in New Issue
Block a user