From 1d71468ab279bd10c84ab9c6b3e1445ad049f7f0 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 19 Nov 2025 16:10:45 -0800 Subject: [PATCH] feat: don't yield tool return message back in hitl [LET-6012] (#6219) feat: don't yield tool return message back in hitl --- letta/agents/letta_agent_v3.py | 2 +- tests/integration_test_human_in_the_loop.py | 50 +- .../integration_test_human_in_the_loop.py | 631 ++++++++---------- 3 files changed, 301 insertions(+), 382 deletions(-) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 2f17fc23..8b224924 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -638,7 +638,7 @@ class LettaAgentV3(LettaAgentV2): self.response_messages.extend(aggregated_persisted[new_message_idx:]) self.response_messages_for_metadata.extend(aggregated_persisted[new_message_idx:]) # Track for job metadata - if llm_adapter.supports_token_streaming(): + if llm_adapter.supports_token_streaming() and tool_calls: # Stream each tool return if tools were executed response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"] for tr in response_tool_returns: diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 1a343253..2e6abc03 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -441,24 +441,11 @@ def test_approve_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 3 or len(messages) == 5 or len(messages) == 6 assert messages[0].message_type == "tool_return_message" assert messages[0].tool_call_id == tool_call_id assert messages[0].status == "success" - if len(messages) == 4: - assert messages[1].message_type == "stop_reason" - assert messages[2].message_type == "usage_statistics" - elif len(messages) == 5: - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" - elif len(messages) == 6: - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "tool_call_message" - assert messages[3].message_type == "tool_return_message" - assert messages[4].message_type == "stop_reason" - assert messages[5].message_type == "usage_statistics" + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_approve_cursor_fetch( @@ -474,7 +461,6 @@ def test_approve_cursor_fetch( tool_call_id = response.messages[-1].tool_call.tool_call_id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - assert len(messages) == 4 assert messages[0].message_type == "user_message" assert messages[-1].message_type == "approval_request_message" # Ensure no request_heartbeat on approval request @@ -584,18 +570,9 @@ def test_approve_and_follow_up( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 or len(messages) == 5 - if len(messages) == 4: - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" - elif len(messages) == 5: - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "tool_call_message" - assert messages[2].message_type == "tool_return_message" - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"] + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_approve_and_follow_up_with_error( @@ -690,13 +667,10 @@ def test_deny_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert messages[0].message_type == "tool_return_message" - assert messages[0].tool_call_id == tool_call_id - assert messages[0].status == "error" - if messages[1].message_type == "assistant_message": + if messages[0].message_type == "assistant_message": + assert SECRET_CODE in messages[0].content + elif messages[1].message_type == "assistant_message": assert SECRET_CODE in messages[1].content - elif messages[2].message_type == "assistant_message": - assert SECRET_CODE in messages[2].content def test_deny_cursor_fetch( @@ -922,13 +896,9 @@ def test_client_side_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert messages[0].message_type == "tool_return_message" - assert messages[0].tool_call_id == tool_call_id - assert messages[0].status == "success" - assert messages[0].tool_return == SECRET_CODE - if messages[1].message_type == "assistant_message": + if messages[0].message_type == "assistant_message": assert SECRET_CODE in messages[1].content - elif messages[2].message_type == "assistant_message": + elif messages[1].message_type == "assistant_message": assert SECRET_CODE in messages[2].content assert messages[-2].message_type == "stop_reason" assert messages[-1].message_type == "usage_statistics" diff --git a/tests/sdk_v1/integration/integration_test_human_in_the_loop.py b/tests/sdk_v1/integration/integration_test_human_in_the_loop.py index d9d171a1..fbfeda81 100644 --- a/tests/sdk_v1/integration/integration_test_human_in_the_loop.py +++ b/tests/sdk_v1/integration/integration_test_human_in_the_loop.py @@ -180,11 +180,11 @@ def agent(client: Letta, approval_tool_fixture, dice_tool_fixture) -> AgentState def test_send_approval_without_pending_request(client, agent): - with pytest.raises(APIError, match="No tool call is currently awaiting approval"): + with pytest.raises(ApiError, match="No tool call is currently awaiting approval"): client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy approval_request_id=FAKE_REQUEST_ID, # legacy approvals=[ @@ -205,13 +205,13 @@ def test_send_user_message_with_pending_request(client, agent): messages=USER_MESSAGE_TEST_APPROVAL, ) - with pytest.raises(APIError, match="Please approve or deny the pending request before continuing"): + with pytest.raises(ApiError, match="Please approve or deny the pending request before continuing"): client.agents.messages.create( agent_id=agent.id, - messages=[MessageCreateParam(role="user", content="hi")], + messages=[MessageCreate(role="user", content="hi")], ) - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) def test_send_approval_message_with_incorrect_request_id(client, agent): @@ -220,11 +220,11 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): messages=USER_MESSAGE_TEST_APPROVAL, ) - with pytest.raises(APIError, match="Invalid tool call IDs"): + with pytest.raises(ApiError, match="Invalid tool call IDs"): client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy approval_request_id=FAKE_REQUEST_ID, # legacy approvals=[ @@ -238,7 +238,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): ], ) - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) # ------------------------------ @@ -258,32 +258,29 @@ def test_invoke_approval_request( messages = response.messages assert messages is not None - assert len(messages) == 3 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "approval_request_message" - assert messages[2].tool_call is not None - assert messages[2].tool_call.name == "get_secret_code_tool" - assert messages[2].tool_calls is not None - assert len(messages[2].tool_calls) == 1 - assert messages[2].tool_calls[0].name == "get_secret_code_tool" + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call is not None + assert messages[-1].tool_call.name == "get_secret_code_tool" + assert messages[-1].tool_calls is not None + assert len(messages[-1].tool_calls) == 1 + assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool" # v3/v1 path: approval request tool args must not include request_heartbeat import json as _json - _args = _json.loads(messages[2].tool_call.arguments) + _args = _json.loads(messages[-1].tool_call.arguments) assert "request_heartbeat" not in _args - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) def test_invoke_approval_request_stream( client: Letta, agent: AgentState, ) -> None: - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True, @@ -292,35 +289,32 @@ def test_invoke_approval_request_stream( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 5 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "approval_request_message" - assert messages[2].tool_call is not None - assert messages[2].tool_call.name == "get_secret_code_tool" - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + assert messages[-3].message_type == "approval_request_message" + assert messages[-3].tool_call is not None + assert messages[-3].tool_call.name == "get_secret_code_tool" + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) - approve_tool_call(client, agent.id, messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, messages[-3].tool_call.tool_call_id) def test_invoke_tool_after_turning_off_requires_approval( client: Letta, agent: AgentState, - approval_tool_fixture, + approval_tool_fixture: Tool, ) -> None: response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=False, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) approvals=[ @@ -336,13 +330,13 @@ def test_invoke_tool_after_turning_off_requires_approval( ) messages = accumulate_chunks(response) - client.agents.tools.update_approval( + client.agents.tools.modify_approval( agent_id=agent.id, tool_name=approval_tool_fixture.name, - body_requires_approval=False, + requires_approval=False, ) - response = client.agents.messages.stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True) + response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True) messages = accumulate_chunks(response) @@ -389,12 +383,12 @@ def test_approve_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=False, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) approvals=[ @@ -412,55 +406,38 @@ def test_approve_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 3 or len(messages) == 5 or len(messages) == 6 assert messages[0].message_type == "tool_return_message" assert messages[0].tool_call_id == tool_call_id assert messages[0].status == "success" - if len(messages) == 4: - assert messages[1].message_type == "stop_reason" - assert messages[2].message_type == "usage_statistics" - elif len(messages) == 5: - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" - elif len(messages) == 6: - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "tool_call_message" - assert messages[3].message_type == "tool_return_message" - assert messages[4].message_type == "stop_reason" - assert messages[5].message_type == "usage_statistics" + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_approve_cursor_fetch( client: Letta, agent: AgentState, ) -> None: - last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) last_message_id = response.messages[0].id - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - messages = messages_page.items - assert len(messages) == 4 + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert messages[0].message_type == "user_message" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "approval_request_message" + assert messages[-1].message_type == "approval_request_message" # Ensure no request_heartbeat on approval request import json as _json - _args = _json.loads(messages[3].tool_call.arguments) + _args = _json.loads(messages[-1].tool_call.arguments) assert "request_heartbeat" not in _args client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=False, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) approvals=[ @@ -474,19 +451,14 @@ def test_approve_cursor_fetch( ], ) - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - messages = messages_page.items - assert len(messages) == 2 or len(messages) == 4 + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) assert messages[0].message_type == "approval_response_message" assert messages[0].approval_request_id == tool_call_id assert messages[0].approve is True - assert messages[0].approvals[0].approve is True - assert messages[0].approvals[0].tool_call_id == tool_call_id + assert messages[0].approvals[0]["approve"] is True + assert messages[0].approvals[0]["tool_call_id"] == tool_call_id assert messages[1].message_type == "tool_return_message" assert messages[1].status == "success" - if len(messages) == 4: - assert messages[2].message_type == "reasoning_message" - assert messages[3].message_type == "assistant_message" def test_approve_with_context_check( @@ -497,12 +469,12 @@ def test_approve_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=False, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) approvals=[ @@ -520,7 +492,7 @@ def test_approve_with_context_check( messages = accumulate_chunks(response) try: - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) except Exception as e: if len(messages) > 4: raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") @@ -535,12 +507,12 @@ def test_approve_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=False, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) approvals=[ @@ -554,7 +526,7 @@ def test_approve_and_follow_up( ], ) - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=USER_MESSAGE_FOLLOW_UP, stream_tokens=True, @@ -562,19 +534,65 @@ def test_approve_and_follow_up( messages = accumulate_chunks(response) + assert messages is not None + assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"] + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" + + +def test_approve_and_follow_up_with_error( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + # Mock the streaming adapter to return llm invocation failure on the follow up turn + with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0] + assert stop_reason_message + assert stop_reason_message.stop_reason == "invalid_llm_response" + + # Ensure that agent is not bricked + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + ) + + messages = accumulate_chunks(response) + assert messages is not None assert len(messages) == 4 or len(messages) == 5 + assert messages[0].message_type == "reasoning_message" if len(messages) == 4: - assert messages[0].message_type == "reasoning_message" assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" - elif len(messages) == 5: - assert messages[0].message_type == "reasoning_message" + else: assert messages[1].message_type == "tool_call_message" assert messages[2].message_type == "tool_return_message" - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" # ------------------------------ @@ -590,12 +608,12 @@ def test_deny_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy @@ -614,37 +632,28 @@ def test_deny_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 5 - assert messages[0].message_type == "tool_return_message" - assert messages[0].tool_call_id == tool_call_id - assert messages[0].status == "error" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert SECRET_CODE in messages[2].content - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + if messages[0].message_type == "assistant_message": + assert SECRET_CODE in messages[0].content + elif messages[1].message_type == "assistant_message": + assert SECRET_CODE in messages[1].content def test_deny_cursor_fetch( client: Letta, agent: AgentState, ) -> None: - last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) last_message_id = response.messages[0].id - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - messages = messages_page.items - assert len(messages) == 4 + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert messages[0].message_type == "user_message" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "approval_request_message" - assert messages[3].tool_call.tool_call_id == tool_call_id + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call.tool_call_id == tool_call_id # Ensure no request_heartbeat on approval request # import json as _json @@ -654,7 +663,7 @@ def test_deny_cursor_fetch( client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy @@ -670,17 +679,13 @@ def test_deny_cursor_fetch( ], ) - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - messages = messages_page.items - assert len(messages) == 4 + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) assert messages[0].message_type == "approval_response_message" - assert messages[0].approvals[0].approve == False - assert messages[0].approvals[0].tool_call_id == tool_call_id - assert messages[0].approvals[0].reason == f"You don't need to call the tool, the secret code is {SECRET_CODE}" + assert messages[0].approvals[0]["approve"] == False + assert messages[0].approvals[0]["tool_call_id"] == tool_call_id + assert messages[0].approvals[0]["reason"] == f"You don't need to call the tool, the secret code is {SECRET_CODE}" assert messages[1].message_type == "tool_return_message" assert messages[1].status == "error" - assert messages[2].message_type == "reasoning_message" - assert messages[3].message_type == "assistant_message" def test_deny_with_context_check( @@ -691,12 +696,12 @@ def test_deny_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy @@ -716,7 +721,7 @@ def test_deny_with_context_check( messages = accumulate_chunks(response) try: - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) except Exception as e: if len(messages) > 4: raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") @@ -731,12 +736,12 @@ def test_deny_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy @@ -752,7 +757,7 @@ def test_deny_and_follow_up( ], ) - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=USER_MESSAGE_FOLLOW_UP, stream_tokens=True, @@ -761,58 +766,62 @@ def test_deny_and_follow_up( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" -def test_agent_records_last_stop_reason_after_approval_flow( +def test_deny_and_follow_up_with_error( client: Letta, agent: AgentState, ) -> None: - """ - Test that the agent's last_stop_reason is properly updated after a human-in-the-loop flow. - This verifies the integration between run completion and agent state updates. - """ - # Get initial agent state - initial_agent = client.agents.retrieve(agent_id=agent.id) - initial_stop_reason = initial_agent.last_stop_reason - - # Trigger approval request response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + # Mock the streaming adapter to return llm invocation failure on the follow up turn + with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}", + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) - # Verify we got an approval request - messages = response.messages assert messages is not None - assert len(messages) == 3 - assert messages[2].message_type == "approval_request_message" + stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0] + assert stop_reason_message + assert stop_reason_message.stop_reason == "invalid_llm_response" - # Check agent after approval request (run should be paused with requires_approval) - agent_after_request = client.agents.retrieve(agent_id=agent.id) - assert agent_after_request.last_stop_reason == "requires_approval" - - # Approve the tool call - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) - - # Check agent after approval (run should complete with end_turn or similar) - agent_after_approval = client.agents.retrieve(agent_id=agent.id) - assert agent_after_approval.last_stop_reason is not None - assert agent_after_approval.last_stop_reason != initial_stop_reason - - # Send follow-up message to complete the flow - response2 = client.agents.messages.create( + # Ensure that agent is not bricked + response = client.agents.messages.create_stream( agent_id=agent.id, messages=USER_MESSAGE_FOLLOW_UP, ) - # Verify final agent state has the most recent stop reason - final_agent = client.agents.retrieve(agent_id=agent.id) - assert final_agent.last_stop_reason is not None + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" # -------------------------------- @@ -828,12 +837,12 @@ def test_client_side_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy @@ -852,38 +861,30 @@ def test_client_side_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 5 - assert messages[0].message_type == "tool_return_message" - assert messages[0].tool_call_id == tool_call_id - assert messages[0].status == "success" - assert messages[0].tool_return == SECRET_CODE - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert SECRET_CODE in messages[2].content - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + if messages[0].message_type == "assistant_message": + assert SECRET_CODE in messages[1].content + elif messages[1].message_type == "assistant_message": + assert SECRET_CODE in messages[2].content + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_client_side_tool_call_cursor_fetch( client: Letta, agent: AgentState, ) -> None: - last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) last_message_id = response.messages[0].id - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - messages = messages_page.items - assert len(messages) == 4 + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert messages[0].message_type == "user_message" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "approval_request_message" - assert messages[3].tool_call.tool_call_id == tool_call_id + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call.tool_call_id == tool_call_id # Ensure no request_heartbeat on approval request # import json as _json @@ -893,7 +894,7 @@ def test_client_side_tool_call_cursor_fetch( client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy @@ -909,20 +910,16 @@ def test_client_side_tool_call_cursor_fetch( ], ) - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - messages = messages_page.items - assert len(messages) == 4 + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) assert messages[0].message_type == "approval_response_message" - assert messages[0].approvals[0].type == "tool" - assert messages[0].approvals[0].tool_call_id == tool_call_id - assert messages[0].approvals[0].tool_return == SECRET_CODE - assert messages[0].approvals[0].status == "success" + assert messages[0].approvals[0]["type"] == "tool" + assert messages[0].approvals[0]["tool_call_id"] == tool_call_id + assert messages[0].approvals[0]["tool_return"] == SECRET_CODE + assert messages[0].approvals[0]["status"] == "success" assert messages[1].message_type == "tool_return_message" assert messages[1].status == "success" assert messages[1].tool_call_id == tool_call_id assert messages[1].tool_return == SECRET_CODE - assert messages[2].message_type == "reasoning_message" - assert messages[3].message_type == "assistant_message" def test_client_side_tool_call_with_context_check( @@ -933,12 +930,12 @@ def test_client_side_tool_call_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy @@ -958,7 +955,7 @@ def test_client_side_tool_call_with_context_check( messages = accumulate_chunks(response) try: - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) except Exception as e: if len(messages) > 4: raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") @@ -973,12 +970,12 @@ def test_client_side_tool_call_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=True, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy @@ -994,7 +991,7 @@ def test_client_side_tool_call_and_follow_up( ], ) - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=USER_MESSAGE_FOLLOW_UP, stream_tokens=True, @@ -1003,29 +1000,69 @@ def test_client_side_tool_call_and_follow_up( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" + + +def test_client_side_tool_call_and_follow_up_with_error( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + # Mock the streaming adapter to return llm invocation failure on the follow up turn + with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0] + assert stop_reason_message + assert stop_reason_message.stop_reason == "invalid_llm_response" + + # Ensure that agent is not bricked + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_parallel_tool_calling( client: Letta, agent: AgentState, ) -> None: - # Parallel tool calling only works for Anthropic models - retrieved_agent = client.agents.retrieve(agent_id=agent.id) - model = None - if hasattr(retrieved_agent, "llm_config") and retrieved_agent.llm_config and hasattr(retrieved_agent.llm_config, "model"): - model = retrieved_agent.llm_config.model - elif hasattr(retrieved_agent, "model") and retrieved_agent.model: - model = retrieved_agent.model - - if not model or not model.startswith("anthropic/"): - pytest.skip("Parallel tool calling test only applies to Anthropic models.") - - last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id response = client.agents.messages.create( agent_id=agent.id, messages=USER_MESSAGE_PARALLEL_TOOL_CALL, @@ -1034,121 +1071,34 @@ def test_parallel_tool_calling( messages = response.messages assert messages is not None - assert len(messages) == 3 or len(messages) == 4 - assert messages[0].message_type == "reasoning_message" + assert messages[-2].message_type == "tool_call_message" + assert len(messages[-2].tool_calls) == 1 + assert messages[-2].tool_calls[0]["name"] == "roll_dice_tool" + assert "6" in messages[-2].tool_calls[0]["arguments"] + dice_tool_call_id = messages[-2].tool_calls[0]["tool_call_id"] - # Handle cases where assistant_message might be missing - idx = 1 - if len(messages) == 4: - # If 4 messages, expect assistant_message, tool_call_message, approval_request_message - assert messages[1].message_type == "assistant_message" - idx = 2 - else: - # If 3 messages, might skip assistant_message and go straight to tool_call_message - pass + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call is not None + assert messages[-1].tool_call.name == "get_secret_code_tool" - # Find the tool_call_message and approval_request_message dynamically - # roll_dice_tool doesn't require approval, so it might be executed immediately - # and appear as tool_return_message instead of tool_call_message - tool_call_msg_idx = None - approval_request_msg_idx = None - tool_return_msg_idx = None - dice_tool_call_id = None - - for i, msg in enumerate(messages): - if msg.message_type == "tool_call_message": - tool_call_msg_idx = i - elif msg.message_type == "approval_request_message": - approval_request_msg_idx = i - elif msg.message_type == "tool_return_message": - tool_return_msg_idx = i - - assert approval_request_msg_idx is not None, f"Expected approval_request_message. Message types: {[m.message_type for m in messages]}" - - # Try to find roll_dice_tool - it could be in tool_call_message or already executed in tool_return_message - if tool_call_msg_idx is not None: - # Check if tool_call_message has roll_dice_tool - tool_calls = messages[tool_call_msg_idx].tool_calls - for tool_call in tool_calls: - if tool_call.name == "roll_dice_tool": - assert "6" in tool_call.arguments - dice_tool_call_id = tool_call.tool_call_id - break - - # If we didn't find it in tool_call_message, check tool_return_message - if dice_tool_call_id is None and tool_return_msg_idx is not None: - tool_return_msg = messages[tool_return_msg_idx] - if hasattr(tool_return_msg, "tool_call_id") and tool_return_msg.tool_call_id: - dice_tool_call_id = tool_return_msg.tool_call_id - - # If still not found, check if roll_dice_tool is in approval_request_message's tool_calls - if dice_tool_call_id is None and approval_request_msg_idx is not None: - approval_msg = messages[approval_request_msg_idx] - if hasattr(approval_msg, "tool_calls") and approval_msg.tool_calls: - for tool_call in approval_msg.tool_calls: - if tool_call.name == "roll_dice_tool": - assert "6" in tool_call.arguments - dice_tool_call_id = tool_call.tool_call_id - break - - # Get approval_request_message tool calls - approval_msg = messages[approval_request_msg_idx] - assert approval_msg.tool_call is not None - assert approval_msg.tool_call.name == "get_secret_code_tool" - - # Find the 3 get_secret_code_tool calls (might also have roll_dice_tool if combined) - get_secret_code_calls = [] - for tool_call in approval_msg.tool_calls: - if tool_call.name == "get_secret_code_tool": - get_secret_code_calls.append(tool_call) - elif tool_call.name == "roll_dice_tool" and dice_tool_call_id is None: - # Found roll_dice_tool in approval_request_message - assert "6" in tool_call.arguments - dice_tool_call_id = tool_call.tool_call_id - - assert len(get_secret_code_calls) == 3, ( - f"Expected 3 get_secret_code_tool calls, found {len(get_secret_code_calls)}. All tool calls: {[tc.name for tc in approval_msg.tool_calls]}" - ) - - assert "hello world" in get_secret_code_calls[0].arguments - approve_tool_call_id = get_secret_code_calls[0].tool_call_id - assert "hello letta" in get_secret_code_calls[1].arguments - deny_tool_call_id = get_secret_code_calls[1].tool_call_id - assert "hello test" in get_secret_code_calls[2].arguments - client_side_tool_call_id = get_secret_code_calls[2].tool_call_id - - # If we still don't have dice_tool_call_id, get it from DB messages - if dice_tool_call_id is None: - db_messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - db_messages = db_messages_page.items - # Look for tool_call_message or tool_return_message with roll_dice_tool - for db_msg in db_messages: - if db_msg.message_type == "tool_call_message" and hasattr(db_msg, "tool_calls") and db_msg.tool_calls: - for tool_call in db_msg.tool_calls: - if tool_call.name == "roll_dice_tool": - dice_tool_call_id = tool_call.tool_call_id - break - if dice_tool_call_id: - break - elif db_msg.message_type == "tool_return_message" and hasattr(db_msg, "tool_call_id") and db_msg.tool_call_id: - # Check if this might be roll_dice_tool by checking nearby messages - if dice_tool_call_id is None and len(db_messages) > 2: - potential_id = db_msg.tool_call_id - dice_tool_call_id = potential_id - break - - # Ensure we have dice_tool_call_id before proceeding - assert dice_tool_call_id is not None, ( - f"Could not find roll_dice_tool call_id. Message types in initial response: {[m.message_type for m in messages]}" - ) + assert len(messages[-1].tool_calls) == 3 + assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool" + assert "hello world" in messages[-1].tool_calls[0]["arguments"] + approve_tool_call_id = messages[-1].tool_calls[0]["tool_call_id"] + assert messages[-1].tool_calls[1]["name"] == "get_secret_code_tool" + assert "hello letta" in messages[-1].tool_calls[1]["arguments"] + deny_tool_call_id = messages[-1].tool_calls[1]["tool_call_id"] + assert messages[-1].tool_calls[2]["name"] == "get_secret_code_tool" + assert "hello test" in messages[-1].tool_calls[2]["arguments"] + client_side_tool_call_id = messages[-1].tool_calls[2]["tool_call_id"] # ensure context is not bricked - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) response = client.agents.messages.create( agent_id=agent.id, messages=[ - ApprovalCreateParam( + ApprovalCreate( approve=False, # legacy (passing incorrect value to ensure it is overridden) approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) approvals=[ @@ -1180,16 +1130,16 @@ def test_parallel_tool_calling( assert messages[0].message_type == "tool_return_message" assert len(messages[0].tool_returns) == 4 for tool_return in messages[0].tool_returns: - if tool_return.tool_call_id == approve_tool_call_id: - assert tool_return.status == "success" - elif tool_return.tool_call_id == deny_tool_call_id: - assert tool_return.status == "error" - elif tool_return.tool_call_id == client_side_tool_call_id: - assert tool_return.status == "success" - assert tool_return.tool_return == SECRET_CODE + if tool_return["tool_call_id"] == approve_tool_call_id: + assert tool_return["status"] == "success" + elif tool_return["tool_call_id"] == deny_tool_call_id: + assert tool_return["status"] == "error" + elif tool_return["tool_call_id"] == client_side_tool_call_id: + assert tool_return["status"] == "success" + assert tool_return["tool_return"] == SECRET_CODE else: - assert tool_return.tool_call_id == dice_tool_call_id - assert tool_return.status == "success" + assert tool_return["tool_call_id"] == dice_tool_call_id + assert tool_return["status"] == "success" if len(messages) == 3: assert messages[1].message_type == "reasoning_message" assert messages[2].message_type == "assistant_message" @@ -1199,10 +1149,9 @@ def test_parallel_tool_calling( assert messages[3].message_type == "tool_return_message" # ensure context is not bricked - client.agents.retrieve(agent_id=agent.id) + client.agents.context.retrieve(agent_id=agent.id) - messages_page = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - messages = messages_page.items + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert len(messages) > 6 assert messages[0].message_type == "user_message" assert messages[1].message_type == "reasoning_message" @@ -1212,7 +1161,7 @@ def test_parallel_tool_calling( assert messages[5].message_type == "approval_response_message" assert messages[6].message_type == "tool_return_message" - response = client.agents.messages.stream( + response = client.agents.messages.create_stream( agent_id=agent.id, messages=USER_MESSAGE_FOLLOW_UP, stream_tokens=True, @@ -1249,20 +1198,20 @@ def test_agent_records_last_stop_reason_after_approval_flow( # Verify we got an approval request messages = response.messages assert messages is not None - assert len(messages) == 3 - assert messages[2].message_type == "approval_request_message" + assert messages[-1].message_type == "approval_request_message" # Check agent after approval request (run should be paused with requires_approval) agent_after_request = client.agents.retrieve(agent_id=agent.id) assert agent_after_request.last_stop_reason == "requires_approval" # Approve the tool call - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) # Check agent after approval (run should complete with end_turn or similar) agent_after_approval = client.agents.retrieve(agent_id=agent.id) + # After approval and run completion, stop reason should be updated (could be end_turn or other terminal reason) assert agent_after_approval.last_stop_reason is not None - assert agent_after_approval.last_stop_reason != initial_stop_reason + assert agent_after_approval.last_stop_reason != initial_stop_reason # Should be different from initial # Send follow-up message to complete the flow response2 = client.agents.messages.create(