feat: add comprehensive testing for client side tool calling (#5539)
This commit is contained in:
@@ -27,7 +27,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessage, MessageType
|
||||
from letta.schemas.letta_message import ApprovalReturn, LettaMessage, MessageType
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
@@ -318,19 +318,19 @@ class LettaAgentV3(LettaAgentV2):
|
||||
pending_tool_calls = {
|
||||
backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a
|
||||
for a in approval_response.approvals
|
||||
if a.type == "approval" and a.approve
|
||||
if isinstance(a, ApprovalReturn) and a.approve
|
||||
}
|
||||
tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls]
|
||||
|
||||
# Get tool calls that were denied
|
||||
denies = {d.tool_call_id: d for d in approval_response.approvals if d.type == "approval" and not d.approve}
|
||||
denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve}
|
||||
tool_call_denials = [
|
||||
ToolCallDenial(**t.model_dump(), reason=denies.get(t.id).reason) for t in approval_request.tool_calls if t.id in denies
|
||||
]
|
||||
|
||||
# Get tool calls that were executed client side
|
||||
if approval_response.approvals:
|
||||
tool_returns = [r for r in approval_response.approvals if r.type == "tool"]
|
||||
tool_returns = [r for r in approval_response.approvals if isinstance(r, ToolReturn)]
|
||||
|
||||
step_id = approval_request.step_id
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
|
||||
|
||||
@@ -220,7 +220,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def test_send_message_with_requires_approval_tool(
|
||||
def test_invoke_approval_request(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
@@ -246,7 +246,45 @@ def test_send_message_with_requires_approval_tool(
|
||||
assert messages[4].message_type == "usage_statistics"
|
||||
|
||||
|
||||
def test_send_message_after_turning_off_requires_approval(
|
||||
def test_invoke_approval_request_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Approve a pending tool call over a token stream, then retrieve the agent context.

    The legacy ``approve`` / ``approval_request_id`` fields are deliberately
    given wrong values; the entry in ``approvals`` is expected to win.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                approvals=[
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": tool_call_id,
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        # More than four accumulated messages means this run did not hit the
        # reasoning-only edge case the test targets, so ask for a rerun while
        # keeping the original failure attached as the cause.
        if len(messages) > 4:
            raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") from e
        # Bare raise re-raises the active exception without adding a frame.
        raise
|
||||
|
||||
|
||||
def test_invoke_tool_after_turning_off_requires_approval(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
approval_tool_fixture: Tool,
|
||||
@@ -434,6 +472,44 @@ def test_approve_cursor_fetch(
|
||||
assert messages[3].message_type == "assistant_message"
|
||||
|
||||
|
||||
def test_approve_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Approve the requested tool call via a token stream and confirm context retrieval works.

    The legacy ``approve`` / ``approval_request_id`` fields carry wrong values
    on purpose; the ``approvals`` entry is expected to override them.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                approvals=[
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": tool_call_id,
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        # More than four accumulated messages means this run did not hit the
        # reasoning-only edge case the test targets, so ask for a rerun while
        # keeping the original failure attached as the cause.
        if len(messages) > 4:
            raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") from e
        # Bare raise re-raises the active exception without adding a frame.
        raise
|
||||
|
||||
|
||||
def test_approve_and_follow_up(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
@@ -643,6 +719,46 @@ def test_deny_cursor_fetch(
|
||||
assert messages[3].message_type == "assistant_message"
|
||||
|
||||
|
||||
def test_deny_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Deny the requested tool call via a token stream and confirm context retrieval works.

    The legacy ``approve`` / ``approval_request_id`` fields carry wrong values
    on purpose; the denial in ``approvals`` is expected to override them.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason="Cancelled by user. Instead of responding, wait for next user input before replying.",  # legacy
                approvals=[
                    {
                        "type": "approval",
                        "approve": False,
                        "tool_call_id": tool_call_id,
                        "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.",
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        # More than four accumulated messages means this run did not hit the
        # reasoning-only edge case the test targets, so ask for a rerun while
        # keeping the original failure attached as the cause.
        if len(messages) > 4:
            raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") from e
        # Bare raise re-raises the active exception without adding a frame.
        raise
|
||||
|
||||
|
||||
def test_deny_and_follow_up(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
@@ -688,46 +804,6 @@ def test_deny_and_follow_up(
|
||||
assert messages[3].message_type == "usage_statistics"
|
||||
|
||||
|
||||
def test_deny_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Deny the requested tool call via a token stream and confirm context retrieval works.

    The legacy ``approve`` / ``approval_request_id`` fields carry wrong values
    on purpose; the denial in ``approvals`` is expected to override them.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason="Cancelled by user. Instead of responding, wait for next user input before replying.",  # legacy
                approvals=[
                    {
                        "type": "approval",
                        "approve": False,
                        "tool_call_id": tool_call_id,
                        "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.",
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        # More than four accumulated messages means this run did not hit the
        # reasoning-only edge case the test targets, so ask for a rerun while
        # keeping the original failure attached as the cause.
        if len(messages) > 4:
            raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") from e
        # Bare raise re-raises the active exception without adding a frame.
        raise
|
||||
|
||||
|
||||
def test_deny_and_follow_up_with_error(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
@@ -781,3 +857,251 @@ def test_deny_and_follow_up_with_error(
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "stop_reason"
|
||||
assert messages[3].message_type == "usage_statistics"
|
||||
|
||||
|
||||
# --------------------------------
|
||||
# Client-Side Execution Test Cases
|
||||
# --------------------------------
|
||||
|
||||
|
||||
def test_client_side_tool_call_request(
    client: Letta,
    agent: AgentState,
) -> None:
    """Answer a pending tool call with a client-side tool return and check the streamed reply."""
    initial = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = initial.messages[2].tool_call.tool_call_id

    # The tool result is supplied by the client instead of an approve/deny decision.
    client_result = {
        "type": "tool",
        "tool_call_id": tool_call_id,
        "tool_return": SECRET_CODE,
        "status": "success",
    }
    approval = ApprovalCreate(
        approve=True,  # legacy (passing incorrect value to ensure it is overridden)
        approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
        reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
        approvals=[client_result],
    )
    stream = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[approval],
    )

    messages = accumulate_chunks(stream)

    assert messages is not None
    assert len(messages) == 5
    tool_return, reasoning, assistant, stop, usage = messages

    # The client-supplied result is echoed back first.
    assert tool_return.message_type == "tool_return_message"
    assert tool_return.tool_call_id == tool_call_id
    assert tool_return.status == "success"
    assert tool_return.tool_return == SECRET_CODE

    # The model then reasons and answers using the returned secret.
    assert reasoning.message_type == "reasoning_message"
    assert assistant.message_type == "assistant_message"
    assert SECRET_CODE in assistant.content

    assert stop.message_type == "stop_reason"
    assert usage.message_type == "usage_statistics"
|
||||
|
||||
|
||||
def test_client_side_tool_call_cursor_fetch(
    client: Letta,
    agent: AgentState,
) -> None:
    """Check cursor-based message listing before and after a client-side tool return."""
    # Cursor taken before the request turn, so everything the turn adds is listed after it.
    cursor_before_turn = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
    initial = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    cursor_after_request = initial.messages[0].id
    tool_call_id = initial.messages[2].tool_call.tool_call_id

    request_turn = client.agents.messages.list(agent_id=agent.id, after=cursor_before_turn)
    assert len(request_turn) == 4
    user_msg, reasoning_msg, assistant_msg, approval_request = request_turn
    assert user_msg.message_type == "user_message"
    assert reasoning_msg.message_type == "reasoning_message"
    assert assistant_msg.message_type == "assistant_message"
    assert approval_request.message_type == "approval_request_message"
    assert approval_request.tool_call.tool_call_id == tool_call_id
    # NOTE: a check that the approval request's tool-call arguments contain no
    # "request_heartbeat" key is currently disabled in this test.

    client_result = {
        "type": "tool",
        "tool_call_id": tool_call_id,
        "tool_return": SECRET_CODE,
        "status": "success",
    }
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
                approvals=[client_result],
            ),
        ],
    )

    response_turn = client.agents.messages.list(agent_id=agent.id, after=cursor_after_request)
    assert len(response_turn) == 4
    approval_response, tool_return, reasoning_msg, assistant_msg = response_turn

    # The stored approval response keeps the client-side tool return as submitted.
    assert approval_response.message_type == "approval_response_message"
    assert approval_response.approvals[0]["type"] == "tool"
    assert approval_response.approvals[0]["tool_call_id"] == tool_call_id
    assert approval_response.approvals[0]["tool_return"] == SECRET_CODE
    assert approval_response.approvals[0]["status"] == "success"

    assert tool_return.message_type == "tool_return_message"
    assert tool_return.status == "success"
    assert tool_return.tool_call_id == tool_call_id
    assert tool_return.tool_return == SECRET_CODE

    assert reasoning_msg.message_type == "reasoning_message"
    assert assistant_msg.message_type == "assistant_message"
|
||||
|
||||
|
||||
def test_client_side_tool_call_with_context_check(
    client: Letta,
    agent: AgentState,
) -> None:
    """Submit a client-side tool return via a token stream and confirm context retrieval works.

    The legacy ``approve`` / ``approval_request_id`` fields carry wrong values
    on purpose; the ``approvals`` entry is expected to override them.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                # NOTE(review): this legacy reason matches the denial tests rather than the
                # other client-side tests; presumably ignored when `approvals` is given - confirm.
                reason="Cancelled by user. Instead of responding, wait for next user input before replying.",  # legacy
                approvals=[
                    {
                        "type": "tool",
                        "tool_call_id": tool_call_id,
                        "tool_return": SECRET_CODE,
                        "status": "success",
                    },
                ],
            ),
        ],
        stream_tokens=True,
    )

    messages = accumulate_chunks(response)

    try:
        client.agents.context.retrieve(agent_id=agent.id)
    except Exception as e:
        # More than four accumulated messages means this run did not hit the
        # reasoning-only edge case the test targets, so ask for a rerun while
        # keeping the original failure attached as the cause.
        if len(messages) > 4:
            raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") from e
        # Bare raise re-raises the active exception without adding a frame.
        raise
|
||||
|
||||
|
||||
def test_client_side_tool_call_and_follow_up(
    client: Letta,
    agent: AgentState,
) -> None:
    """Send a client-side tool return, then verify a follow-up user turn streams normally."""
    initial = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = initial.messages[2].tool_call.tool_call_id

    client_result = {
        "type": "tool",
        "tool_call_id": tool_call_id,
        "tool_return": SECRET_CODE,
        "status": "success",
    }
    approval = ApprovalCreate(
        approve=True,  # legacy (passing incorrect value to ensure it is overridden)
        approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
        reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
        approvals=[client_result],
    )
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[approval],
    )

    follow_up_stream = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=USER_MESSAGE_FOLLOW_UP,
        stream_tokens=True,
    )

    messages = accumulate_chunks(follow_up_stream)

    # Expected message types of the follow-up turn, in order.
    expected_types = (
        "reasoning_message",
        "assistant_message",
        "stop_reason",
        "usage_statistics",
    )
    assert messages is not None
    assert len(messages) == len(expected_types)
    for message, expected in zip(messages, expected_types):
        assert message.message_type == expected
|
||||
|
||||
|
||||
def test_client_side_tool_call_and_follow_up_with_error(
    client: Letta,
    agent: AgentState,
) -> None:
    """Fail the LLM call on the client-side tool-return turn, then verify the agent still responds.

    The first turn after the tool return runs with ``invoke_llm`` patched to
    raise, and must end with an ``invalid_llm_response`` stop reason. A plain
    follow-up message afterwards must stream the usual four messages.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[2].tool_call.tool_call_id

    # Mock the streaming adapter so the LLM invocation fails on the turn that
    # submits the client-side tool return.
    with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
        response = client.agents.messages.create_stream(
            agent_id=agent.id,
            messages=[
                ApprovalCreate(
                    approve=True,  # legacy (passing incorrect value to ensure it is overridden)
                    approval_request_id=FAKE_REQUEST_ID,  # legacy (passing incorrect value to ensure it is overridden)
                    reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",  # legacy
                    approvals=[
                        {
                            "type": "tool",
                            "tool_call_id": tool_call_id,
                            "tool_return": SECRET_CODE,
                            "status": "success",
                        },
                    ],
                ),
            ],
            stream_tokens=True,
        )

        # Drain the stream while the patch is active (presumably the stream is
        # consumed lazily - keep this inside the `with` block).
        messages = accumulate_chunks(response)

    assert messages is not None
    # next() with a default lets a missing stop_reason fail as an assertion
    # with a message instead of an IndexError from indexing an empty list.
    stop_reason_message = next((m for m in messages if m.message_type == "stop_reason"), None)
    assert stop_reason_message, "expected a stop_reason message after the mocked LLM failure"
    assert stop_reason_message.stop_reason == "invalid_llm_response"

    # Ensure that agent is not bricked
    response = client.agents.messages.create_stream(
        agent_id=agent.id,
        messages=USER_MESSAGE_FOLLOW_UP,
    )

    messages = accumulate_chunks(response)

    assert messages is not None
    assert len(messages) == 4
    assert messages[0].message_type == "reasoning_message"
    assert messages[1].message_type == "assistant_message"
    assert messages[2].message_type == "stop_reason"
    assert messages[3].message_type == "usage_statistics"
|
||||
|
||||
Reference in New Issue
Block a user