test: make hitl testing pass (#6188)

This commit is contained in:
cthomas
2025-11-14 15:13:18 -08:00
committed by Caren Thomas
parent 848aa962b6
commit 41392cdb8a
2 changed files with 82 additions and 118 deletions

View File

@@ -320,6 +320,8 @@ class LettaAgentV3(LettaAgentV2):
# Raise if no chunks sent yet (response not started, can return error status code)
raise
else:
yield f"data: {self.stop_reason.model_dump_json()}\n\n"
# Mid-stream error: yield error event to client in SSE format
error_chunk = {
"error": {

View File

@@ -246,7 +246,7 @@ def test_send_user_message_with_pending_request(client, agent):
messages=[MessageCreate(role="user", content="hi")],
)
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
def test_send_approval_message_with_incorrect_request_id(client, agent):
@@ -273,7 +273,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
],
)
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
# ------------------------------
@@ -293,25 +293,22 @@ def test_invoke_approval_request(
messages = response.messages
assert messages is not None
assert len(messages) == 3
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "approval_request_message"
assert messages[2].tool_call is not None
assert messages[2].tool_call.name == "get_secret_code_tool"
assert messages[2].tool_calls is not None
assert len(messages[2].tool_calls) == 1
assert messages[2].tool_calls[0]["name"] == "get_secret_code_tool"
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call is not None
assert messages[-1].tool_call.name == "get_secret_code_tool"
assert messages[-1].tool_calls is not None
assert len(messages[-1].tool_calls) == 1
assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool"
# v3/v1 path: approval request tool args must not include request_heartbeat
import json as _json
_args = _json.loads(messages[2].tool_call.arguments)
_args = _json.loads(messages[-1].tool_call.arguments)
assert "request_heartbeat" not in _args
client.agents.context.retrieve(agent_id=agent.id)
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
def test_invoke_approval_request_stream(
@@ -327,18 +324,15 @@ def test_invoke_approval_request_stream(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 5
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "approval_request_message"
assert messages[2].tool_call is not None
assert messages[2].tool_call.name == "get_secret_code_tool"
assert messages[3].message_type == "stop_reason"
assert messages[4].message_type == "usage_statistics"
assert messages[-3].message_type == "approval_request_message"
assert messages[-3].tool_call is not None
assert messages[-3].tool_call.name == "get_secret_code_tool"
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
client.agents.context.retrieve(agent_id=agent.id)
approve_tool_call(client, agent.id, messages[2].tool_call.tool_call_id)
approve_tool_call(client, agent.id, messages[-3].tool_call.tool_call_id)
def test_invoke_tool_after_turning_off_requires_approval(
@@ -350,7 +344,7 @@ def test_invoke_tool_after_turning_off_requires_approval(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -424,7 +418,7 @@ def test_approve_tool_call_request(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -477,18 +471,16 @@ def test_approve_cursor_fetch(
messages=USER_MESSAGE_TEST_APPROVAL,
)
last_message_id = response.messages[0].id
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
assert len(messages) == 4
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
assert messages[3].message_type == "approval_request_message"
assert messages[-1].message_type == "approval_request_message"
# Ensure no request_heartbeat on approval request
import json as _json
_args = _json.loads(messages[3].tool_call.arguments)
_args = _json.loads(messages[-1].tool_call.arguments)
assert "request_heartbeat" not in _args
client.agents.messages.create(
@@ -509,7 +501,6 @@ def test_approve_cursor_fetch(
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
assert len(messages) == 2 or len(messages) == 4
assert messages[0].message_type == "approval_response_message"
assert messages[0].approval_request_id == tool_call_id
assert messages[0].approve is True
@@ -517,9 +508,6 @@ def test_approve_cursor_fetch(
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "success"
if len(messages) == 4:
assert messages[2].message_type == "reasoning_message"
assert messages[3].message_type == "assistant_message"
def test_approve_with_context_check(
@@ -530,7 +518,7 @@ def test_approve_with_context_check(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -568,7 +556,7 @@ def test_approve_and_follow_up(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
client.agents.messages.create(
agent_id=agent.id,
@@ -618,7 +606,7 @@ def test_approve_and_follow_up_with_error(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
@@ -678,7 +666,7 @@ def test_deny_tool_call_request(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -702,15 +690,13 @@ def test_deny_tool_call_request(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 5
assert messages[0].message_type == "tool_return_message"
assert messages[0].tool_call_id == tool_call_id
assert messages[0].status == "error"
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
assert SECRET_CODE in messages[2].content
assert messages[3].message_type == "stop_reason"
assert messages[4].message_type == "usage_statistics"
if messages[1].message_type == "assistant_message":
assert SECRET_CODE in messages[1].content
elif messages[2].message_type == "assistant_message":
assert SECRET_CODE in messages[2].content
def test_deny_cursor_fetch(
@@ -723,15 +709,12 @@ def test_deny_cursor_fetch(
messages=USER_MESSAGE_TEST_APPROVAL,
)
last_message_id = response.messages[0].id
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
assert len(messages) == 4
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
assert messages[3].message_type == "approval_request_message"
assert messages[3].tool_call.tool_call_id == tool_call_id
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call.tool_call_id == tool_call_id
# Ensure no request_heartbeat on approval request
# import json as _json
@@ -758,15 +741,12 @@ def test_deny_cursor_fetch(
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
assert len(messages) == 4
assert messages[0].message_type == "approval_response_message"
assert messages[0].approvals[0]["approve"] == False
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
assert messages[0].approvals[0]["reason"] == f"You don't need to call the tool, the secret code is {SECRET_CODE}"
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "error"
assert messages[2].message_type == "reasoning_message"
assert messages[3].message_type == "assistant_message"
def test_deny_with_context_check(
@@ -777,7 +757,7 @@ def test_deny_with_context_check(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -817,7 +797,7 @@ def test_deny_and_follow_up(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
client.agents.messages.create(
agent_id=agent.id,
@@ -847,11 +827,9 @@ def test_deny_and_follow_up(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"
assert len(messages) > 2
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
def test_deny_and_follow_up_with_error(
@@ -862,7 +840,7 @@ def test_deny_and_follow_up_with_error(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
@@ -902,11 +880,9 @@ def test_deny_and_follow_up_with_error(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"
assert len(messages) > 2
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
# --------------------------------
@@ -922,7 +898,7 @@ def test_client_side_tool_call_request(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -946,16 +922,16 @@ def test_client_side_tool_call_request(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 5
assert messages[0].message_type == "tool_return_message"
assert messages[0].tool_call_id == tool_call_id
assert messages[0].status == "success"
assert messages[0].tool_return == SECRET_CODE
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
assert SECRET_CODE in messages[2].content
assert messages[3].message_type == "stop_reason"
assert messages[4].message_type == "usage_statistics"
if messages[1].message_type == "assistant_message":
assert SECRET_CODE in messages[1].content
elif messages[2].message_type == "assistant_message":
assert SECRET_CODE in messages[2].content
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
def test_client_side_tool_call_cursor_fetch(
@@ -968,15 +944,12 @@ def test_client_side_tool_call_cursor_fetch(
messages=USER_MESSAGE_TEST_APPROVAL,
)
last_message_id = response.messages[0].id
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
assert len(messages) == 4
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "assistant_message"
assert messages[3].message_type == "approval_request_message"
assert messages[3].tool_call.tool_call_id == tool_call_id
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call.tool_call_id == tool_call_id
# Ensure no request_heartbeat on approval request
# import json as _json
@@ -1003,7 +976,6 @@ def test_client_side_tool_call_cursor_fetch(
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
assert len(messages) == 4
assert messages[0].message_type == "approval_response_message"
assert messages[0].approvals[0]["type"] == "tool"
assert messages[0].approvals[0]["tool_call_id"] == tool_call_id
@@ -1013,8 +985,6 @@ def test_client_side_tool_call_cursor_fetch(
assert messages[1].status == "success"
assert messages[1].tool_call_id == tool_call_id
assert messages[1].tool_return == SECRET_CODE
assert messages[2].message_type == "reasoning_message"
assert messages[3].message_type == "assistant_message"
def test_client_side_tool_call_with_context_check(
@@ -1025,7 +995,7 @@ def test_client_side_tool_call_with_context_check(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -1065,7 +1035,7 @@ def test_client_side_tool_call_and_follow_up(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
client.agents.messages.create(
agent_id=agent.id,
@@ -1095,11 +1065,9 @@ def test_client_side_tool_call_and_follow_up(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"
assert len(messages) > 2
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
def test_client_side_tool_call_and_follow_up_with_error(
@@ -1110,7 +1078,7 @@ def test_client_side_tool_call_and_follow_up_with_error(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[2].tool_call.tool_call_id
tool_call_id = response.messages[-1].tool_call.tool_call_id
# Mock the streaming adapter to return llm invocation failure on the follow up turn
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
@@ -1150,11 +1118,9 @@ def test_client_side_tool_call_and_follow_up_with_error(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"
assert len(messages) > 2
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
def test_parallel_tool_calling(
@@ -1170,29 +1136,26 @@ def test_parallel_tool_calling(
messages = response.messages
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "tool_call_message"
assert len(messages[2].tool_calls) == 1
assert messages[2].tool_calls[0]["name"] == "roll_dice_tool"
assert "6" in messages[2].tool_calls[0]["arguments"]
dice_tool_call_id = messages[2].tool_calls[0]["tool_call_id"]
assert messages[-2].message_type == "tool_call_message"
assert len(messages[-2].tool_calls) == 1
assert messages[-2].tool_calls[0]["name"] == "roll_dice_tool"
assert "6" in messages[-2].tool_calls[0]["arguments"]
dice_tool_call_id = messages[-2].tool_calls[0]["tool_call_id"]
assert messages[3].message_type == "approval_request_message"
assert messages[3].tool_call is not None
assert messages[3].tool_call.name == "get_secret_code_tool"
assert messages[-1].message_type == "approval_request_message"
assert messages[-1].tool_call is not None
assert messages[-1].tool_call.name == "get_secret_code_tool"
assert len(messages[3].tool_calls) == 3
assert messages[3].tool_calls[0]["name"] == "get_secret_code_tool"
assert "hello world" in messages[3].tool_calls[0]["arguments"]
approve_tool_call_id = messages[3].tool_calls[0]["tool_call_id"]
assert messages[3].tool_calls[1]["name"] == "get_secret_code_tool"
assert "hello letta" in messages[3].tool_calls[1]["arguments"]
deny_tool_call_id = messages[3].tool_calls[1]["tool_call_id"]
assert messages[3].tool_calls[2]["name"] == "get_secret_code_tool"
assert "hello test" in messages[3].tool_calls[2]["arguments"]
client_side_tool_call_id = messages[3].tool_calls[2]["tool_call_id"]
assert len(messages[-1].tool_calls) == 3
assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool"
assert "hello world" in messages[-1].tool_calls[0]["arguments"]
approve_tool_call_id = messages[-1].tool_calls[0]["tool_call_id"]
assert messages[-1].tool_calls[1]["name"] == "get_secret_code_tool"
assert "hello letta" in messages[-1].tool_calls[1]["arguments"]
deny_tool_call_id = messages[-1].tool_calls[1]["tool_call_id"]
assert messages[-1].tool_calls[2]["name"] == "get_secret_code_tool"
assert "hello test" in messages[-1].tool_calls[2]["arguments"]
client_side_tool_call_id = messages[-1].tool_calls[2]["tool_call_id"]
# ensure context is not bricked
client.agents.context.retrieve(agent_id=agent.id)
@@ -1300,15 +1263,14 @@ def test_agent_records_last_stop_reason_after_approval_flow(
# Verify we got an approval request
messages = response.messages
assert messages is not None
assert len(messages) == 3
assert messages[2].message_type == "approval_request_message"
assert messages[-1].message_type == "approval_request_message"
# Check agent after approval request (run should be paused with requires_approval)
agent_after_request = client.agents.retrieve(agent_id=agent.id)
assert agent_after_request.last_stop_reason == "requires_approval"
# Approve the tool call
approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id)
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
# Check agent after approval (run should complete with end_turn or similar)
agent_after_approval = client.agents.retrieve(agent_id=agent.id)