From 505c9cff57a4788caef6e23fc61a4b4551341763 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 16 Oct 2025 14:09:40 -0700 Subject: [PATCH] feat: add testing for new hitl paths (#5493) --- letta/server/rest_api/utils.py | 21 ++- tests/integration_test_human_in_the_loop.py | 195 +++++++++++++++----- 2 files changed, 165 insertions(+), 51 deletions(-) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index c9cfffb4..ccefef89 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -36,7 +36,7 @@ from letta.schemas.letta_message_content import ( TextContent, ) from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import ApprovalCreate, Message, MessageCreate, ToolReturn +from letta.schemas.message import ApprovalCreate, ApprovalReturn, Message, MessageCreate, ToolReturn from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -172,7 +172,16 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti def create_approval_response_message_from_input( agent_state: AgentState, input_message: ApprovalCreate, run_id: Optional[str] = None ) -> List[Message]: - def maybe_convert_tool_return_message(maybe_tool_return: LettaToolReturn): + if not input_message.approvals: + input_message.approvals = [ + ApprovalCreate(approve=input_message.approve, tool_call_id=input_message.approval_request_id, reason=input_message.reason) + ] + approval_messages = [ + a for a in input_message.approvals if isinstance(a, ApprovalReturn) or (isinstance(a, dict) and a.get("type") == "approval") + ] + first_approval = approval_messages[0] if approval_messages else None + + def maybe_convert_return_message(maybe_tool_return): if isinstance(maybe_tool_return, LettaToolReturn): packaged_function_response = package_function_response( maybe_tool_return.status == "success", maybe_tool_return.tool_return, agent_state.timezone @@ -191,10 +200,10 @@ def create_approval_response_message_from_input( role=MessageRole.approval, agent_id=agent_state.id, model=agent_state.llm_config.model, - approval_request_id=input_message.approval_request_id, - approve=input_message.approve, - denial_reason=input_message.reason, - approvals=[maybe_convert_tool_return_message(approval) for approval in input_message.approvals], + approval_request_id=first_approval.tool_call_id if first_approval else input_message.approval_request_id, + approve=first_approval.approve if first_approval else input_message.approve, + denial_reason=first_approval.reason if first_approval else input_message.reason, + approvals=[maybe_convert_return_message(approval) for approval in input_message.approvals], run_id=run_id, ) ] diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index fab9d261..7bf58059 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -161,7 +161,19 @@ def test_send_approval_without_pending_request(client, agent): with pytest.raises(ApiError, match="No tool call is currently awaiting approval"): client.agents.messages.create( agent_id=agent.id, - messages=[ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)], + messages=[ + ApprovalCreate( + approve=True, # legacy + approval_request_id=FAKE_REQUEST_ID, # legacy + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": FAKE_REQUEST_ID, + }, + ], + ), + ], ) @@ -187,7 +199,19 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): with pytest.raises(ApiError, match="Invalid tool call IDs"): client.agents.messages.create( agent_id=agent.id, - messages=[ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)], + messages=[ + ApprovalCreate( + approve=True, # legacy + approval_request_id=FAKE_REQUEST_ID, # legacy + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": FAKE_REQUEST_ID, + }, + ], + ), + ], ) @@ -227,16 +251,26 @@ def test_send_message_after_turning_off_requires_approval( agent: AgentState, approval_tool_fixture: Tool, ) -> None: - response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True) - messages = accumulate_chunks(response) - approval_request_id = messages[0].id + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + approval_request_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ ApprovalCreate( - approve=True, - approval_request_id=approval_request_id, + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], ), ], stream_tokens=True, @@ -308,8 +342,15 @@ def test_approve_tool_call_request( agent_id=agent.id, messages=[ ApprovalCreate( - approve=True, - approval_request_id=approval_request_id, + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], ), ], stream_tokens=True, @@ -347,7 +388,8 @@ def test_approve_cursor_fetch( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + last_message_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert len(messages) == 4 @@ -355,29 +397,36 @@ def test_approve_cursor_fetch( assert messages[1].message_type == "reasoning_message" assert messages[2].message_type == "assistant_message" assert messages[3].message_type == "approval_request_message" - assert messages[3].id == approval_request_id # Ensure no request_heartbeat on approval request import json as _json _args = _json.loads(messages[3].tool_call.arguments) assert "request_heartbeat" not in _args - last_message_cursor = approval_request_id client.agents.messages.create( agent_id=agent.id, messages=[ ApprovalCreate( - approve=True, - approval_request_id=approval_request_id, + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], ), ], ) - messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) assert len(messages) == 2 or len(messages) == 4 assert messages[0].message_type == "approval_response_message" - assert messages[0].approval_request_id == approval_request_id + assert messages[0].approval_request_id == tool_call_id assert messages[0].approve is True + assert messages[0].approvals[0]["approve"] is True + assert messages[0].approvals[0]["tool_call_id"] == tool_call_id assert messages[1].message_type == "tool_return_message" assert messages[1].status == "success" if len(messages) == 4: @@ -393,14 +442,21 @@ def test_approve_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, messages=[ ApprovalCreate( - approve=True, - approval_request_id=approval_request_id, + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], ), ], ) @@ -436,7 +492,7 @@ def test_approve_and_follow_up_with_error( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id # Mock the streaming adapter to return llm invocation failure on the follow up turn with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): @@ -444,8 +500,15 @@ def test_approve_and_follow_up_with_error( agent_id=agent.id, messages=[ ApprovalCreate( - approve=True, - approval_request_id=approval_request_id, + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], ), ], stream_tokens=True, @@ -489,16 +552,23 @@ def test_deny_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id tool_call_id = response.messages[2].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ ApprovalCreate( - approve=False, - approval_request_id=approval_request_id, - reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}", + }, + ], ), ], ) @@ -526,7 +596,8 @@ def test_deny_cursor_fetch( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + last_message_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert len(messages) == 4 @@ -534,28 +605,38 @@ def test_deny_cursor_fetch( assert messages[1].message_type == "reasoning_message" assert messages[2].message_type == "assistant_message" assert messages[3].message_type == "approval_request_message" - assert messages[3].id == approval_request_id + assert messages[3].tool_call.tool_call_id == tool_call_id # Ensure no request_heartbeat on approval request # import json as _json # _args = _json.loads(messages[2].tool_call.arguments) # assert "request_heartbeat" not in _args - last_message_cursor = approval_request_id client.agents.messages.create( agent_id=agent.id, messages=[ ApprovalCreate( - approve=False, - approval_request_id=approval_request_id, - reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}", + }, + ], ), ], ) - messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) assert len(messages) == 4 assert messages[0].message_type == "approval_response_message" + assert messages[0].approvals[0]["approve"] == False + assert messages[0].approvals[0]["tool_call_id"] == tool_call_id + assert messages[0].approvals[0]["reason"] == f"You don't need to call the tool, the secret code is {SECRET_CODE}" assert messages[1].message_type == "tool_return_message" assert messages[1].status == "error" assert messages[2].message_type == "reasoning_message" @@ -570,15 +651,23 @@ def test_deny_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, messages=[ ApprovalCreate( - approve=False, - approval_request_id=approval_request_id, - reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}", + }, + ], ), ], ) @@ -607,15 +696,23 @@ def test_deny_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ ApprovalCreate( - approve=False, - approval_request_id=approval_request_id, - reason="Cancelled by user. Instead of responding, wait for next user input before replying.", + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.", + }, + ], ), ], stream_tokens=True, @@ -639,7 +736,7 @@ def test_deny_and_follow_up_with_error( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - approval_request_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id # Mock the streaming adapter to return llm invocation failure on the follow up turn with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): @@ -647,9 +744,17 @@ def test_deny_and_follow_up_with_error( agent_id=agent.id, messages=[ ApprovalCreate( - approve=False, - approval_request_id=approval_request_id, - reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}", + }, + ], ), ], stream_tokens=True,