feat: add client side tool calling support (#5313)

This commit is contained in:
cthomas
2025-10-10 13:50:55 -07:00
committed by Caren Thomas
parent f79a47dbc1
commit 3128b5e126
5 changed files with 86 additions and 32 deletions

View File

@@ -168,15 +168,19 @@ async def _prepare_in_context_messages_no_persist_async(
# Check for approval-related message validation
if len(input_messages) == 1 and input_messages[0].type == "approval":
# User is trying to send an approval response
approval_message = input_messages[0].approvals[0]
if current_in_context_messages[-1].role != "approval":
raise ValueError(
"Cannot process approval response: No tool call is currently awaiting approval. "
"Please send a regular message to interact with the agent."
)
if input_messages[0].approval_request_id != current_in_context_messages[-1].id:
if (
approval_message.tool_call_id != current_in_context_messages[-1].id
and approval_message.tool_call_id != current_in_context_messages[-1].tool_calls[0].id
):
raise ValueError(
f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' "
f"but received '{input_messages[0].approval_request_id}'."
f"Invalid tool call ID. Expected '{current_in_context_messages[-1].tool_calls[0].id}' "
f"but received '{approval_message.tool_call_id}'."
)
new_in_context_messages = create_approval_response_message_from_input(agent_state=agent_state, input_message=input_messages[0])
else:

View File

@@ -19,16 +19,17 @@ from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
from letta.errors import ContextWindowExceededError, LLMError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessage, MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
@@ -397,6 +398,9 @@ class LettaAgentV3(LettaAgentV2):
is_approval=approval_response.approve if approval_response is not None else False,
is_denial=(approval_response.approve == False) if approval_response is not None else False,
denial_reason=approval_response.denial_reason if approval_response is not None else None,
tool_return=approval_response.approvals[0]
if approval_response and approval_response.approvals and isinstance(approval_response.approvals[0], ToolReturn)
else None,
)
# NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op")
@@ -542,6 +546,7 @@ class LettaAgentV3(LettaAgentV2):
is_approval: bool | None = None,
is_denial: bool | None = None,
denial_reason: str | None = None,
tool_return: ToolReturn | None = None,
) -> tuple[list[Message], bool, LettaStopReason | None]:
"""
Handle the final AI response once streaming completes, execute / validate the
@@ -553,29 +558,45 @@ class LettaAgentV3(LettaAgentV2):
else:
tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
if is_denial:
if is_denial or tool_return is not None:
continue_stepping = True
stop_reason = None
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call.function.name,
function_arguments={},
tool_execution_result=ToolExecutionResult(status="error"),
tool_call_id=tool_call_id,
function_response=f"Error: request to call tool denied. User reason: {denial_reason}",
timezone=agent_state.timezone,
continue_stepping=continue_stepping,
# NOTE: we may need to change this to not have a "heartbeat" prefix for v3?
heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.",
reasoning_content=None,
pre_computed_assistant_message_id=None,
step_id=step_id,
run_id=run_id,
is_approval_response=True,
force_set_request_heartbeat=False,
add_heartbeat_on_continue=False,
)
if tool_return is not None:
tool_call_messages = [
Message(
role=MessageRole.tool,
content=[TextContent(text=tool_return.func_response)],
agent_id=agent_state.id,
model=agent_state.llm_config.model,
tool_calls=[],
tool_call_id=tool_return.tool_call_id,
created_at=get_utc_time(),
tool_returns=[tool_return],
run_id=run_id,
step_id=step_id,
)
]
else:
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call.function.name,
function_arguments={},
tool_execution_result=ToolExecutionResult(status="error"),
tool_call_id=tool_call_id,
function_response=f"Error: request to call tool denied. User reason: {denial_reason}",
timezone=agent_state.timezone,
continue_stepping=continue_stepping,
# NOTE: we may need to change this to not have a "heartbeat" prefix for v3?
heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.",
reasoning_content=None,
pre_computed_assistant_message_id=None,
step_id=step_id,
run_id=run_id,
is_approval_response=True,
force_set_request_heartbeat=False,
add_heartbeat_on_continue=False,
)
messages_to_persist = (initial_messages or []) + tool_call_messages
# Set run_id on all messages before persisting

View File

@@ -33,6 +33,7 @@ from letta.schemas.letta_message import (
SystemMessage,
ToolCall,
ToolCallMessage,
ToolReturn as LettaToolReturn,
ToolReturnMessage,
UserMessage,
)
@@ -344,17 +345,29 @@ class Message(BaseMessage):
messages.append(approval_request_message)
else:
if self.approvals:
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0]
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)]
def maybe_convert_tool_return_message(maybe_tool_return: ToolReturn):
    """Convert a stored ``ToolReturn`` into a ``LettaToolReturn`` for the response message.

    Entries of any other type are handed back unchanged, so this helper can
    be mapped over every item in ``self.approvals``.

    Args:
        maybe_tool_return: One entry of ``self.approvals``.

    Returns:
        A new ``LettaToolReturn`` when the entry is a ``ToolReturn``;
        otherwise the entry itself.
    """
    # Test for the *stored* type (ToolReturn), not the target type: checking
    # LettaToolReturn here would let stored tool returns through unconverted.
    if isinstance(maybe_tool_return, ToolReturn):
        return LettaToolReturn(
            tool_call_id=maybe_tool_return.tool_call_id,
            status=maybe_tool_return.status,
            # The stored model carries its payload in `func_response`.
            # NOTE(review): this looks like the packaged response string built
            # when the approval was ingested — confirm whether it should be
            # unpacked before being returned to the client.
            tool_return=maybe_tool_return.func_response,
            stdout=maybe_tool_return.stdout,
            stderr=maybe_tool_return.stderr,
        )
    return maybe_tool_return
approval_response_message = ApprovalResponseMessage(
id=self.id,
date=self.created_at,
otid=self.otid,
approvals=self.approvals,
approvals=[maybe_convert_tool_return_message(approval) for approval in self.approvals],
run_id=self.run_id,
# TODO: temporary populate these fields for backwards compatibility
approve=first_approval.approve,
approval_request_id=first_approval.tool_call_id,
reason=first_approval.reason,
approve=first_approval[0].approve if first_approval else None,
approval_request_id=first_approval[0].tool_call_id if first_approval else None,
reason=first_approval[0].reason if first_approval else None,
)
else:
approval_response_message = ApprovalResponseMessage(

View File

@@ -27,6 +27,7 @@ from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import tracer
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ToolReturn as LettaToolReturn
from letta.schemas.letta_message_content import (
OmittedReasoningContent,
ReasoningContent,
@@ -169,6 +170,20 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
def create_approval_response_message_from_input(agent_state: AgentState, input_message: ApprovalCreate) -> List[Message]:
def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn):
    """Turn a ``LettaToolReturn`` from the request into a ``ToolReturn``.

    The tool's output is run through ``package_function_response`` together
    with a success flag and the agent's timezone, and the result is stored as
    ``func_response``. Entries of any other type are returned unchanged.

    Args:
        maybe_tool_return: One entry of ``input_message.approvals``.

    Returns:
        A new ``ToolReturn`` when the entry is a ``LettaToolReturn``;
        otherwise the entry itself.
    """
    # Anything that is not a tool return is forwarded untouched.
    if not isinstance(maybe_tool_return, LettaToolReturn):
        return maybe_tool_return

    was_successful = maybe_tool_return.status == "success"
    packaged = package_function_response(
        was_successful,
        maybe_tool_return.tool_return,
        agent_state.timezone,
    )
    return ToolReturn(
        tool_call_id=maybe_tool_return.tool_call_id,
        status=maybe_tool_return.status,
        func_response=packaged,
        stdout=maybe_tool_return.stdout,
        stderr=maybe_tool_return.stderr,
    )
return [
Message(
role=MessageRole.approval,
@@ -177,6 +192,7 @@ def create_approval_response_message_from_input(agent_state: AgentState, input_m
approval_request_id=input_message.approval_request_id,
approve=input_message.approve,
denial_reason=input_message.reason,
approvals=[maybe_convert_tool_return_message(approval) for approval in input_message.approvals],
)
]

View File

@@ -184,7 +184,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
messages=USER_MESSAGE_TEST_APPROVAL,
)
with pytest.raises(ApiError, match="Invalid approval request ID"):
with pytest.raises(ApiError, match="Invalid tool call ID"):
client.agents.messages.create(
agent_id=agent.id,
messages=[ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)],