test: make hitl tests pass using v1 sdk LET-6312 (#6353)

test: make hitl tests pass using v1 sdk
This commit is contained in:
cthomas
2025-11-24 16:22:54 -08:00
committed by Caren Thomas
parent 29e38a2a42
commit e4fb00fef8

View File

@@ -1,12 +1,15 @@
import logging
import uuid
from typing import List
from typing import Any, List
from unittest.mock import patch
import pytest
from letta_client import APIError, Letta
from letta_client.types import AgentState, MessageCreateParam
from letta_client.types import AgentState, MessageCreateParam, Tool
from letta_client.types.agents import ApprovalCreateParam
from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
logger = logging.getLogger(__name__)
# ------------------------------
@@ -180,11 +183,11 @@ def agent(client: Letta, approval_tool_fixture, dice_tool_fixture) -> AgentState
def test_send_approval_without_pending_request(client, agent):
with pytest.raises(ApiError, match="No tool call is currently awaiting approval"):
with pytest.raises(APIError, match="No tool call is currently awaiting approval"):
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy
approval_request_id=FAKE_REQUEST_ID, # legacy
approvals=[
@@ -205,10 +208,10 @@ def test_send_user_message_with_pending_request(client, agent):
messages=USER_MESSAGE_TEST_APPROVAL,
)
with pytest.raises(ApiError, match="Please approve or deny the pending request before continuing"):
with pytest.raises(APIError, match="Please approve or deny the pending request before continuing"):
client.agents.messages.create(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="hi")],
messages=[MessageCreateParam(role="user", content="hi")],
)
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
@@ -220,11 +223,11 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
messages=USER_MESSAGE_TEST_APPROVAL,
)
with pytest.raises(ApiError, match="Invalid tool call IDs"):
with pytest.raises(APIError, match="Invalid tool call IDs"):
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy
approval_request_id=FAKE_REQUEST_ID, # legacy
approvals=[
@@ -263,7 +266,7 @@ def test_invoke_approval_request(
assert messages[-1].tool_call.name == "get_secret_code_tool"
assert messages[-1].tool_calls is not None
assert len(messages[-1].tool_calls) == 1
assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool"
assert messages[-1].tool_calls[0].name == "get_secret_code_tool"
# v3/v1 path: approval request tool args must not include request_heartbeat
import json as _json
@@ -271,7 +274,7 @@ def test_invoke_approval_request(
_args = _json.loads(messages[-1].tool_call.arguments)
assert "request_heartbeat" not in _args
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
@@ -280,7 +283,7 @@ def test_invoke_approval_request_stream(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
stream_tokens=True,
@@ -295,7 +298,7 @@ def test_invoke_approval_request_stream(
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
approve_tool_call(client, agent.id, messages[-3].tool_call.tool_call_id)
@@ -311,10 +314,10 @@ def test_invoke_tool_after_turning_off_requires_approval(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -330,13 +333,13 @@ def test_invoke_tool_after_turning_off_requires_approval(
)
messages = accumulate_chunks(response)
client.agents.tools.modify_approval(
client.agents.tools.update_approval(
agent_id=agent.id,
tool_name=approval_tool_fixture.name,
requires_approval=False,
body_requires_approval=False,
)
response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
response = client.agents.messages.stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
messages = accumulate_chunks(response)
@@ -385,10 +388,10 @@ def test_approve_tool_call_request(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -417,7 +420,7 @@ def test_approve_cursor_fetch(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
@@ -425,7 +428,7 @@ def test_approve_cursor_fetch(
last_message_id = response.messages[0].id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert messages[0].message_type == "user_message"
assert messages[-1].message_type == "approval_request_message"
# Ensure no request_heartbeat on approval request
@@ -437,7 +440,7 @@ def test_approve_cursor_fetch(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -451,12 +454,12 @@ def test_approve_cursor_fetch(
],
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
assert messages[0].message_type == "approval_response_message"
assert messages[0].approval_request_id == tool_call_id
assert messages[0].approve is True
assert messages[0].approvals[0]["approve"] is True
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0].approve is True
assert messages[0].approvals[0].tool_call_id == tool_call_id
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "success"
@@ -471,10 +474,10 @@ def test_approve_with_context_check(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -492,7 +495,7 @@ def test_approve_with_context_check(
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
@@ -512,7 +515,7 @@ def test_approve_and_follow_up(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -526,7 +529,7 @@ def test_approve_and_follow_up(
],
)
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
@@ -552,10 +555,10 @@ def test_approve_and_follow_up_with_error(
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -570,15 +573,11 @@ def test_approve_and_follow_up_with_error(
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
with pytest.raises(APIError, match="TEST: Mocked error"):
messages = accumulate_chunks(response)
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
@@ -610,10 +609,10 @@ def test_deny_tool_call_request(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -642,7 +641,7 @@ def test_deny_cursor_fetch(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
@@ -650,7 +649,7 @@ def test_deny_cursor_fetch(
last_message_id = response.messages[0].id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert messages[0].message_type == "user_message"
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call.tool_call_id == tool_call_id
@@ -663,7 +662,7 @@ def test_deny_cursor_fetch(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -679,11 +678,11 @@ def test_deny_cursor_fetch(
],
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
assert messages[0].message_type == "approval_response_message"
assert messages[0].approvals[0]["approve"] == False
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0]["reason"] == f"You don't need to call the tool, the secret code is {SECRET_CODE}"
assert messages[0].approvals[0].approve == False
assert messages[0].approvals[0].tool_call_id == tool_call_id
assert messages[0].approvals[0].reason == f"You don't need to call the tool, the secret code is {SECRET_CODE}"
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "error"
@@ -698,10 +697,10 @@ def test_deny_with_context_check(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy
@@ -721,7 +720,7 @@ def test_deny_with_context_check(
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
@@ -741,7 +740,7 @@ def test_deny_and_follow_up(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -757,7 +756,7 @@ def test_deny_and_follow_up(
],
)
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
@@ -783,10 +782,10 @@ def test_deny_and_follow_up_with_error(
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -803,15 +802,11 @@ def test_deny_and_follow_up_with_error(
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
with pytest.raises(APIError, match="TEST: Mocked error"):
messages = accumulate_chunks(response)
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
@@ -839,10 +834,10 @@ def test_client_side_tool_call_request(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -873,7 +868,7 @@ def test_client_side_tool_call_cursor_fetch(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
@@ -881,7 +876,7 @@ def test_client_side_tool_call_cursor_fetch(
last_message_id = response.messages[0].id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert messages[0].message_type == "user_message"
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call.tool_call_id == tool_call_id
@@ -894,7 +889,7 @@ def test_client_side_tool_call_cursor_fetch(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -910,12 +905,12 @@ def test_client_side_tool_call_cursor_fetch(
],
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
assert messages[0].message_type == "approval_response_message"
assert messages[0].approvals[0]["type"] == "tool"
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0]["tool_return"] == SECRET_CODE
assert messages[0].approvals[0]["status"] == "success"
assert messages[0].approvals[0].type == "tool"
assert messages[0].approvals[0].tool_call_id == tool_call_id
assert messages[0].approvals[0].tool_return == SECRET_CODE
assert messages[0].approvals[0].status == "success"
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "success"
assert messages[1].tool_call_id == tool_call_id
@@ -932,10 +927,10 @@ def test_client_side_tool_call_with_context_check(
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy
@@ -955,7 +950,7 @@ def test_client_side_tool_call_with_context_check(
messages = accumulate_chunks(response)
try:
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
except Exception as e:
if len(messages) > 4:
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
@@ -975,7 +970,7 @@ def test_client_side_tool_call_and_follow_up(
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -991,7 +986,7 @@ def test_client_side_tool_call_and_follow_up(
],
)
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
@@ -1017,10 +1012,10 @@ def test_client_side_tool_call_and_follow_up_with_error(
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=True, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy
@@ -1037,15 +1032,11 @@ def test_client_side_tool_call_and_follow_up_with_error(
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "invalid_llm_response"
with pytest.raises(APIError, match="TEST: Mocked error"):
messages = accumulate_chunks(response)
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
@@ -1062,7 +1053,7 @@ def test_parallel_tool_calling(
client: Letta,
agent: AgentState,
) -> None:
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
@@ -1073,32 +1064,32 @@ def test_parallel_tool_calling(
assert messages is not None
assert messages[-2].message_type == "tool_call_message"
assert len(messages[-2].tool_calls) == 1
assert messages[-2].tool_calls[0]["name"] == "roll_dice_tool"
assert "6" in messages[-2].tool_calls[0]["arguments"]
dice_tool_call_id = messages[-2].tool_calls[0]["tool_call_id"]
assert messages[-2].tool_calls[0].name == "roll_dice_tool"
assert "6" in messages[-2].tool_calls[0].arguments
dice_tool_call_id = messages[-2].tool_calls[0].tool_call_id
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call is not None
assert messages[-1].tool_call.name == "get_secret_code_tool"
assert len(messages[-1].tool_calls) == 3
assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool"
assert "hello world" in messages[-1].tool_calls[0]["arguments"]
approve_tool_call_id = messages[-1].tool_calls[0]["tool_call_id"]
assert messages[-1].tool_calls[1]["name"] == "get_secret_code_tool"
assert "hello letta" in messages[-1].tool_calls[1]["arguments"]
deny_tool_call_id = messages[-1].tool_calls[1]["tool_call_id"]
assert messages[-1].tool_calls[2]["name"] == "get_secret_code_tool"
assert "hello test" in messages[-1].tool_calls[2]["arguments"]
client_side_tool_call_id = messages[-1].tool_calls[2]["tool_call_id"]
assert messages[-1].tool_calls[0].name == "get_secret_code_tool"
assert "hello world" in messages[-1].tool_calls[0].arguments
approve_tool_call_id = messages[-1].tool_calls[0].tool_call_id
assert messages[-1].tool_calls[1].name == "get_secret_code_tool"
assert "hello letta" in messages[-1].tool_calls[1].arguments
deny_tool_call_id = messages[-1].tool_calls[1].tool_call_id
assert messages[-1].tool_calls[2].name == "get_secret_code_tool"
assert "hello test" in messages[-1].tool_calls[2].arguments
client_side_tool_call_id = messages[-1].tool_calls[2].tool_call_id
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
ApprovalCreateParam(
approve=False, # legacy (passing incorrect value to ensure it is overridden)
approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden)
approvals=[
@@ -1130,16 +1121,16 @@ def test_parallel_tool_calling(
assert messages[0].message_type == "tool_return_message"
assert len(messages[0].tool_returns) == 4
for tool_return in messages[0].tool_returns:
if tool_return["tool_call_id"] == approve_tool_call_id:
assert tool_return["status"] == "success"
elif tool_return["tool_call_id"] == deny_tool_call_id:
assert tool_return["status"] == "error"
elif tool_return["tool_call_id"] == client_side_tool_call_id:
assert tool_return["status"] == "success"
assert tool_return["tool_return"] == SECRET_CODE
if tool_return.tool_call_id == approve_tool_call_id:
assert tool_return.status == "success"
elif tool_return.tool_call_id == deny_tool_call_id:
assert tool_return.status == "error"
elif tool_return.tool_call_id == client_side_tool_call_id:
assert tool_return.status == "success"
assert tool_return.tool_return == SECRET_CODE
else:
assert tool_return["tool_call_id"] == dice_tool_call_id
assert tool_return["status"] == "success"
assert tool_return.tool_call_id == dice_tool_call_id
assert tool_return.status == "success"
if len(messages) == 3:
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
@@ -1149,9 +1140,9 @@ def test_parallel_tool_calling(
assert messages[3].message_type == "tool_return_message"
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
assert len(messages) > 6
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
@@ -1161,7 +1152,7 @@ def test_parallel_tool_calling(
assert messages[5].message_type == "approval_response_message"
assert messages[6].message_type == "tool_return_message"
response = client.agents.messages.create_stream(
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,