feat: add client side tool calling support (#5313)
This commit is contained in:
@@ -168,15 +168,19 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
# Check for approval-related message validation
|
||||
if len(input_messages) == 1 and input_messages[0].type == "approval":
|
||||
# User is trying to send an approval response
|
||||
approval_message = input_messages[0].approvals[0]
|
||||
if current_in_context_messages[-1].role != "approval":
|
||||
raise ValueError(
|
||||
"Cannot process approval response: No tool call is currently awaiting approval. "
|
||||
"Please send a regular message to interact with the agent."
|
||||
)
|
||||
if input_messages[0].approval_request_id != current_in_context_messages[-1].id:
|
||||
if (
|
||||
approval_message.tool_call_id != current_in_context_messages[-1].id
|
||||
and approval_message.tool_call_id != current_in_context_messages[-1].tool_calls[0].id
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' "
|
||||
f"but received '{input_messages[0].approval_request_id}'."
|
||||
f"Invalid tool call ID. Expected '{current_in_context_messages[-1].tool_calls[0].id}' "
|
||||
f"but received '{approval_message.tool_call_id}'."
|
||||
)
|
||||
new_in_context_messages = create_approval_response_message_from_input(agent_state=agent_state, input_message=input_messages[0])
|
||||
else:
|
||||
|
||||
@@ -19,16 +19,17 @@ from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
|
||||
from letta.errors import ContextWindowExceededError, LLMError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessage, MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.schemas.step import StepProgression
|
||||
from letta.schemas.step_metrics import StepMetrics
|
||||
@@ -397,6 +398,9 @@ class LettaAgentV3(LettaAgentV2):
|
||||
is_approval=approval_response.approve if approval_response is not None else False,
|
||||
is_denial=(approval_response.approve == False) if approval_response is not None else False,
|
||||
denial_reason=approval_response.denial_reason if approval_response is not None else None,
|
||||
tool_return=approval_response.approvals[0]
|
||||
if approval_response and approval_response.approvals and isinstance(approval_response.approvals[0], ToolReturn)
|
||||
else None,
|
||||
)
|
||||
# NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op")
|
||||
|
||||
@@ -542,6 +546,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
is_approval: bool | None = None,
|
||||
is_denial: bool | None = None,
|
||||
denial_reason: str | None = None,
|
||||
tool_return: ToolReturn | None = None,
|
||||
) -> tuple[list[Message], bool, LettaStopReason | None]:
|
||||
"""
|
||||
Handle the final AI response once streaming completes, execute / validate the
|
||||
@@ -553,29 +558,45 @@ class LettaAgentV3(LettaAgentV2):
|
||||
else:
|
||||
tool_call_id: str = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
if is_denial:
|
||||
if is_denial or tool_return is not None:
|
||||
continue_stepping = True
|
||||
stop_reason = None
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call.function.name,
|
||||
function_arguments={},
|
||||
tool_execution_result=ToolExecutionResult(status="error"),
|
||||
tool_call_id=tool_call_id,
|
||||
function_response=f"Error: request to call tool denied. User reason: {denial_reason}",
|
||||
timezone=agent_state.timezone,
|
||||
continue_stepping=continue_stepping,
|
||||
# NOTE: we may need to change this to not have a "heartbeat" prefix for v3?
|
||||
heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.",
|
||||
reasoning_content=None,
|
||||
pre_computed_assistant_message_id=None,
|
||||
step_id=step_id,
|
||||
run_id=run_id,
|
||||
is_approval_response=True,
|
||||
force_set_request_heartbeat=False,
|
||||
add_heartbeat_on_continue=False,
|
||||
)
|
||||
if tool_return is not None:
|
||||
tool_call_messages = [
|
||||
Message(
|
||||
role=MessageRole.tool,
|
||||
content=[TextContent(text=tool_return.func_response)],
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
tool_calls=[],
|
||||
tool_call_id=tool_return.tool_call_id,
|
||||
created_at=get_utc_time(),
|
||||
tool_returns=[tool_return],
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
]
|
||||
else:
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call.function.name,
|
||||
function_arguments={},
|
||||
tool_execution_result=ToolExecutionResult(status="error"),
|
||||
tool_call_id=tool_call_id,
|
||||
function_response=f"Error: request to call tool denied. User reason: {denial_reason}",
|
||||
timezone=agent_state.timezone,
|
||||
continue_stepping=continue_stepping,
|
||||
# NOTE: we may need to change this to not have a "heartbeat" prefix for v3?
|
||||
heartbeat_reason=f"{NON_USER_MSG_PREFIX}Continuing: user denied request to call tool.",
|
||||
reasoning_content=None,
|
||||
pre_computed_assistant_message_id=None,
|
||||
step_id=step_id,
|
||||
run_id=run_id,
|
||||
is_approval_response=True,
|
||||
force_set_request_heartbeat=False,
|
||||
add_heartbeat_on_continue=False,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
|
||||
# Set run_id on all messages before persisting
|
||||
|
||||
@@ -33,6 +33,7 @@ from letta.schemas.letta_message import (
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolCallMessage,
|
||||
ToolReturn as LettaToolReturn,
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
@@ -344,17 +345,29 @@ class Message(BaseMessage):
|
||||
messages.append(approval_request_message)
|
||||
else:
|
||||
if self.approvals:
|
||||
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0]
|
||||
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)]
|
||||
|
||||
def maybe_convert_tool_return_message(maybe_tool_return: ToolReturn):
|
||||
if isinstance(maybe_tool_return, LettaToolReturn):
|
||||
return LettaToolReturn(
|
||||
tool_call_id=maybe_tool_return.tool_call_id,
|
||||
status=maybe_tool_return.status,
|
||||
tool_return=maybe_tool_return.tool_return,
|
||||
stdout=maybe_tool_return.stdout,
|
||||
stderr=maybe_tool_return.stderr,
|
||||
)
|
||||
return maybe_tool_return
|
||||
|
||||
approval_response_message = ApprovalResponseMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
otid=self.otid,
|
||||
approvals=self.approvals,
|
||||
approvals=[maybe_convert_tool_return_message(approval) for approval in self.approvals],
|
||||
run_id=self.run_id,
|
||||
# TODO: temporary populate these fields for backwards compatibility
|
||||
approve=first_approval.approve,
|
||||
approval_request_id=first_approval.tool_call_id,
|
||||
reason=first_approval.reason,
|
||||
approve=first_approval[0].approve if first_approval else None,
|
||||
approval_request_id=first_approval[0].tool_call_id if first_approval else None,
|
||||
reason=first_approval[0].reason if first_approval else None,
|
||||
)
|
||||
else:
|
||||
approval_response_message = ApprovalResponseMessage(
|
||||
|
||||
@@ -27,6 +27,7 @@ from letta.otel.metric_registry import MetricRegistry
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import ToolReturn as LettaToolReturn
|
||||
from letta.schemas.letta_message_content import (
|
||||
OmittedReasoningContent,
|
||||
ReasoningContent,
|
||||
@@ -169,6 +170,20 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
|
||||
|
||||
|
||||
def create_approval_response_message_from_input(agent_state: AgentState, input_message: ApprovalCreate) -> List[Message]:
|
||||
def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn):
|
||||
if isinstance(maybe_tool_return, LettaToolReturn):
|
||||
packaged_function_response = package_function_response(
|
||||
maybe_tool_return.status == "success", maybe_tool_return.tool_return, agent_state.timezone
|
||||
)
|
||||
return ToolReturn(
|
||||
tool_call_id=maybe_tool_return.tool_call_id,
|
||||
status=maybe_tool_return.status,
|
||||
func_response=packaged_function_response,
|
||||
stdout=maybe_tool_return.stdout,
|
||||
stderr=maybe_tool_return.stderr,
|
||||
)
|
||||
return maybe_tool_return
|
||||
|
||||
return [
|
||||
Message(
|
||||
role=MessageRole.approval,
|
||||
@@ -177,6 +192,7 @@ def create_approval_response_message_from_input(agent_state: AgentState, input_m
|
||||
approval_request_id=input_message.approval_request_id,
|
||||
approve=input_message.approve,
|
||||
denial_reason=input_message.reason,
|
||||
approvals=[maybe_convert_tool_return_message(approval) for approval in input_message.approvals],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
|
||||
with pytest.raises(ApiError, match="Invalid approval request ID"):
|
||||
with pytest.raises(ApiError, match="Invalid tool call ID"):
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)],
|
||||
|
||||
Reference in New Issue
Block a user