feat: add comprehensive testing for client side tool calling (#5539)

This commit is contained in:
cthomas
2025-10-17 12:40:57 -07:00
committed by Caren Thomas
parent 69b15d606c
commit 5a475fd1a5
2 changed files with 370 additions and 46 deletions

View File

@@ -27,7 +27,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessage, MessageType
from letta.schemas.letta_message import ApprovalReturn, LettaMessage, MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
@@ -318,19 +318,19 @@ class LettaAgentV3(LettaAgentV2):
pending_tool_calls = {
backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a
for a in approval_response.approvals
if a.type == "approval" and a.approve
if isinstance(a, ApprovalReturn) and a.approve
}
tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls]
# Get tool calls that were denied
denies = {d.tool_call_id: d for d in approval_response.approvals if d.type == "approval" and not d.approve}
denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve}
tool_call_denials = [
ToolCallDenial(**t.model_dump(), reason=denies.get(t.id).reason) for t in approval_request.tool_calls if t.id in denies
]
# Get tool calls that were executed client side
if approval_response.approvals:
tool_returns = [r for r in approval_response.approvals if r.type == "tool"]
tool_returns = [r for r in approval_response.approvals if isinstance(r, ToolReturn)]
step_id = approval_request.step_id
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)

View File

@@ -220,7 +220,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
# ------------------------------
def test_send_message_with_requires_approval_tool(
def test_invoke_approval_request(
client: Letta,
agent: AgentState,
) -> None:
@@ -246,7 +246,45 @@ def test_send_message_with_requires_approval_tool(
assert messages[4].message_type == "usage_statistics"
def test_send_message_after_turning_off_requires_approval(
def test_invoke_approval_request_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Stream an approval for a pending tool call, then retrieve the agent context.

    The first request sends ``USER_MESSAGE_TEST_APPROVAL`` and the tool call id is
    read from the third message of that response. The approval is streamed back
    (``stream_tokens=True``) with deliberately wrong legacy fields next to the
    ``approvals`` list, and the agent context is retrieved afterwards.

    Raises:
        ValueError: context retrieval failed and more than four messages were
            accumulated from the stream; the test asks to be rerun.
        Exception: the original retrieval error in every other failure case.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                approvals=[
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": tool_call_id,
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        if len(messages) > 4:
            # Chain explicitly so the retrieval failure is reported as the direct cause.
            raise ValueError(
                "Model did not respond with only reasoning content, please rerun test to repro edge case."
            ) from e
        # Bare raise keeps the original traceback untouched.
        raise
def test_invoke_tool_after_turning_off_requires_approval(
client: Letta,
agent: AgentState,
approval_tool_fixture: Tool,
@@ -434,6 +472,44 @@ def test_approve_cursor_fetch(
assert messages[3].message_type == "assistant_message"
def test_approve_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Approve a pending tool call over a token stream, then retrieve the agent context.

    The tool call id comes from the third message of the response to
    ``USER_MESSAGE_TEST_APPROVAL``. The approval carries deliberately wrong legacy
    fields next to the ``approvals`` list. After the stream is accumulated the
    agent context is retrieved.

    Raises:
        ValueError: context retrieval failed and more than four messages were
            accumulated from the stream; the test asks to be rerun.
        Exception: the original retrieval error in every other failure case.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                approvals=[
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": tool_call_id,
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        if len(messages) > 4:
            # Chain explicitly so the retrieval failure is reported as the direct cause.
            raise ValueError(
                "Model did not respond with only reasoning content, please rerun test to repro edge case."
            ) from e
        # Bare raise keeps the original traceback untouched.
        raise
def test_approve_and_follow_up(
client: Letta,
agent: AgentState,
@@ -643,6 +719,46 @@ def test_deny_cursor_fetch(
assert messages[3].message_type == "assistant_message"
def test_deny_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Deny a pending tool call over a token stream, then retrieve the agent context.

    The tool call id comes from the third message of the response to
    ``USER_MESSAGE_TEST_APPROVAL``. The denial (``"approve": False`` plus a reason)
    is sent in the ``approvals`` list while the legacy ``approve`` and
    ``approval_request_id`` fields are deliberately wrong. After the stream is
    accumulated the agent context is retrieved.

    Raises:
        ValueError: context retrieval failed and more than four messages were
            accumulated from the stream; the test asks to be rerun.
        Exception: the original retrieval error in every other failure case.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason="Cancelled by user. Instead of responding, wait for next user input before replying.",  # legacy
                approvals=[
                    {
                        "type": "approval",
                        "approve": False,
                        "tool_call_id": tool_call_id,
                        "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.",
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        if len(messages) > 4:
            # Chain explicitly so the retrieval failure is reported as the direct cause.
            raise ValueError(
                "Model did not respond with only reasoning content, please rerun test to repro edge case."
            ) from e
        # Bare raise keeps the original traceback untouched.
        raise
def test_deny_and_follow_up(
client: Letta,
agent: AgentState,
@@ -688,46 +804,6 @@ def test_deny_and_follow_up(
assert messages[3].message_type == "usage_statistics"
def test_deny_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Deny a pending tool call over a token stream, then retrieve the agent context.

    The tool call id comes from the third message of the response to
    ``USER_MESSAGE_TEST_APPROVAL``. The denial (``"approve": False`` plus a reason)
    is sent in the ``approvals`` list while the legacy ``approve`` and
    ``approval_request_id`` fields are deliberately wrong. After the stream is
    accumulated the agent context is retrieved.

    Raises:
        ValueError: context retrieval failed and more than four messages were
            accumulated from the stream; the test asks to be rerun.
        Exception: the original retrieval error in every other failure case.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason="Cancelled by user. Instead of responding, wait for next user input before replying.",  # legacy
                approvals=[
                    {
                        "type": "approval",
                        "approve": False,
                        "tool_call_id": tool_call_id,
                        "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.",
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        if len(messages) > 4:
            # Chain explicitly so the retrieval failure is reported as the direct cause.
            raise ValueError(
                "Model did not respond with only reasoning content, please rerun test to repro edge case."
            ) from e
        # Bare raise keeps the original traceback untouched.
        raise
def test_deny_and_follow_up_with_error(
client: Letta,
agent: AgentState,
@@ -781,3 +857,251 @@ def test_deny_and_follow_up_with_error(
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"
# --------------------------------
# Client-Side Execution Test Cases
# --------------------------------
def test_client_side_tool_call_request(
    client: Letta,
    agent: AgentState,
) -> None:
    """Submit a client-executed tool result and check the streamed reply.

    After ``USER_MESSAGE_TEST_APPROVAL`` produces a tool call (third message of the
    response), a ``"type": "tool"`` entry carrying ``SECRET_CODE`` is sent back in
    place of an approval. The accumulated stream must contain exactly five
    messages: tool return, reasoning, assistant (mentioning ``SECRET_CODE``),
    stop reason and usage statistics.
    """
    approval_request = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = approval_request.messages[2].tool_call.tool_call_id

    # Result of the tool as executed on the client side.
    client_tool_return = {
        "type": "tool",
        "tool_call_id": tool_call_id,
        "tool_return": SECRET_CODE,
        "status": "success",
    }
    stream = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
                approvals=[client_tool_return],
            ),
        ],
    )

    messages = accumulate_chunks(stream)

    assert messages is not None
    assert len(messages) == 5

    tool_return_msg, reasoning_msg, assistant_msg, stop_msg, usage_msg = messages

    assert tool_return_msg.message_type == "tool_return_message"
    assert tool_return_msg.tool_call_id == tool_call_id
    assert tool_return_msg.status == "success"
    assert tool_return_msg.tool_return == SECRET_CODE
    assert reasoning_msg.message_type == "reasoning_message"
    assert assistant_msg.message_type == "assistant_message"
    assert SECRET_CODE in assistant_msg.content
    assert stop_msg.message_type == "stop_reason"
    assert usage_msg.message_type == "usage_statistics"
def test_client_side_tool_call_cursor_fetch(
    client: Letta,
    agent: AgentState,
) -> None:
    """Check cursor-based message listing around a client-side tool return.

    Two listings are verified: the messages created by the approval-triggering
    request (user, reasoning, assistant, approval request), and the messages
    created after a ``"type": "tool"`` result is submitted (approval response,
    tool return, reasoning, assistant).
    """
    # Cursor pointing at the newest message that exists before the test acts.
    newest_existing = client.agents.messages.list(agent_id=agent.id, limit=1)[0]
    last_message_cursor = newest_existing.id

    approval_request = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    last_message_id = approval_request.messages[0].id
    tool_call_id = approval_request.messages[2].tool_call.tool_call_id

    messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
    assert len(messages) == 4
    user_msg, reasoning_msg, assistant_msg, request_msg = messages
    assert user_msg.message_type == "user_message"
    assert reasoning_msg.message_type == "reasoning_message"
    assert assistant_msg.message_type == "assistant_message"
    assert request_msg.message_type == "approval_request_message"
    assert request_msg.tool_call.tool_call_id == tool_call_id
    # Ensure no request_heartbeat on approval request
    # import json as _json
    # _args = _json.loads(messages[2].tool_call.arguments)
    # assert "request_heartbeat" not in _args

    client_tool_return = {
        "type": "tool",
        "tool_call_id": tool_call_id,
        "tool_return": SECRET_CODE,
        "status": "success",
    }
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
                approvals=[client_tool_return],
            ),
        ],
    )

    messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
    assert len(messages) == 4
    response_msg, tool_return_msg, reasoning_msg, assistant_msg = messages

    assert response_msg.message_type == "approval_response_message"
    stored_return = response_msg.approvals[0]
    assert stored_return["type"] == "tool"
    assert stored_return["tool_call_id"] == tool_call_id
    assert stored_return["tool_return"] == SECRET_CODE
    assert stored_return["status"] == "success"
    assert tool_return_msg.message_type == "tool_return_message"
    assert tool_return_msg.status == "success"
    assert tool_return_msg.tool_call_id == tool_call_id
    assert tool_return_msg.tool_return == SECRET_CODE
    assert reasoning_msg.message_type == "reasoning_message"
    assert assistant_msg.message_type == "assistant_message"
def test_client_side_tool_call_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Stream a client-executed tool result, then retrieve the agent context.

    The tool call id comes from the third message of the response to
    ``USER_MESSAGE_TEST_APPROVAL``. A ``"type": "tool"`` entry carrying
    ``SECRET_CODE`` is streamed back (``stream_tokens=True``) with deliberately
    wrong legacy fields, and the agent context is retrieved afterwards.

    Raises:
        ValueError: context retrieval failed and more than four messages were
            accumulated from the stream; the test asks to be rerun.
        Exception: the original retrieval error in every other failure case.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                # NOTE(review): this legacy reason reads like a denial message even though a
                # tool result is being returned - confirm it is intentional.
                reason="Cancelled by user. Instead of responding, wait for next user input before replying.",  # legacy
                approvals=[
                    {
                        "type": "tool",
                        "tool_call_id": tool_call_id,
                        "tool_return": SECRET_CODE,
                        "status": "success",
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        if len(messages) > 4:
            # Chain explicitly so the retrieval failure is reported as the direct cause.
            raise ValueError(
                "Model did not respond with only reasoning content, please rerun test to repro edge case."
            ) from e
        # Bare raise keeps the original traceback untouched.
        raise
def test_client_side_tool_call_and_follow_up(
    client: Letta,
    agent: AgentState,
) -> None:
    """Submit a client-executed tool result, then check a streamed follow-up turn.

    After the ``"type": "tool"`` result is sent, ``USER_MESSAGE_FOLLOW_UP`` is
    streamed with ``stream_tokens=True``. The accumulated reply must contain
    exactly four messages: reasoning, assistant, stop reason and usage statistics.
    """
    approval_request = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = approval_request.messages[2].tool_call.tool_call_id

    # Result of the tool as executed on the client side.
    client_tool_return = {
        "type": "tool",
        "tool_call_id": tool_call_id,
        "tool_return": SECRET_CODE,
        "status": "success",
    }
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
                approvals=[client_tool_return],
            ),
        ],
    )

    follow_up_stream = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=USER_MESSAGE_FOLLOW_UP,
        stream_tokens=True,
    )
    messages = accumulate_chunks(follow_up_stream)

    assert messages is not None
    assert len(messages) == 4

    reasoning_msg, assistant_msg, stop_msg, usage_msg = messages
    assert reasoning_msg.message_type == "reasoning_message"
    assert assistant_msg.message_type == "assistant_message"
    assert stop_msg.message_type == "stop_reason"
    assert usage_msg.message_type == "usage_statistics"
def test_client_side_tool_call_and_follow_up_with_error(
client: Letta,
agent: AgentState,
) -> None:
# Submits a client-side tool return while SimpleLLMStreamAdapter.invoke_llm is
# patched to raise, expects a stop reason of "invalid_llm_response", and then
# checks that a later USER_MESSAGE_FOLLOW_UP still gets the usual four messages.
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
# The tool call id is read from the third message of the response.
tool_call_id = response.messages[2].tool_call.tool_call_id
# Mock the streaming adapter to return llm invocation failure on the follow up turn
# NOTE(review): the request issued under this patch is the tool-return submission,
# not USER_MESSAGE_FOLLOW_UP; "follow up turn" presumably means the LLM call that
# follows the tool return - confirm.
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
approvals=[
{
"type": "tool",
"tool_call_id": tool_call_id,
"tool_return": SECRET_CODE,
"status": "success",
},
],
),
],
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
# Indexing [0] raises IndexError when the stream contains no stop_reason message.
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
# Ensure that agent is not bricked
# This follow-up request is streamed without stream_tokens.
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"