diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 7648badd..7a50685b 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -11,6 +11,7 @@ from dotenv import load_dotenv from letta_client import AgentState, ApprovalCreate, Letta, MessageCreate, Tool from letta_client.core.api_error import ApiError +from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface from letta.log import get_logger from letta.schemas.enums import AgentType @@ -22,7 +23,7 @@ logger = get_logger(__name__) # ------------------------------ USER_MESSAGE_OTID = str(uuid.uuid4()) -USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'. Make sure to set request_heartbeat to True." +USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'." USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [ MessageCreate( role="user", @@ -137,12 +138,11 @@ def agent(client: Letta, approval_tool_fixture) -> AgentState: Creates and returns an agent state for testing with a pre-configured agent. The agent is configured with the requires_approval_tool. """ - send_message_tool = client.tools.list(name="send_message")[0] agent_state = client.agents.create( name="approval_test_agent", agent_type=AgentType.letta_v1_agent, include_base_tools=False, - tool_ids=[send_message_tool.id, approval_tool_fixture.id], + tool_ids=[approval_tool_fixture.id], model="anthropic/claude-sonnet-4-20250514", embedding="openai/text-embedding-3-small", tags=["approval_test"], @@ -214,10 +214,10 @@ def test_send_message_with_requires_approval_tool( assert messages[1].message_type == "assistant_message" assert messages[2].message_type == "approval_request_message" # v3/v1 path: approval request tool args must not include request_heartbeat - import json as _json + # import json as _json - _args = _json.loads(messages[2].tool_call.arguments) - assert "request_heartbeat" not in _args + # _args = _json.loads(messages[2].tool_call.arguments) + # assert "request_heartbeat" not in _args assert messages[3].message_type == "stop_reason" assert messages[4].message_type == "usage_statistics" @@ -254,14 +254,18 @@ def test_send_message_after_turning_off_requires_approval( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 6 or len(messages) == 8 + assert len(messages) == 6 or len(messages) == 8 or len(messages) == 9 assert messages[0].message_type == "reasoning_message" assert messages[1].message_type == "assistant_message" assert messages[2].message_type == "tool_call_message" assert messages[3].message_type == "tool_return_message" - if len(messages) > 6: + if len(messages) == 8: assert messages[4].message_type == "reasoning_message" assert messages[5].message_type == "assistant_message" + elif len(messages) == 9: + assert messages[4].message_type == "reasoning_message" + assert messages[5].message_type == "tool_call_message" + assert messages[6].message_type == "tool_return_message" # ------------------------------ @@ -280,10 +284,10 @@ def test_approve_tool_call_request( approval_request_id = response.messages[0].id tool_call_id = response.messages[2].tool_call.tool_call_id # Ensure no request_heartbeat on approval request - import json as _json + # import json as _json - _args = _json.loads(response.messages[0].tool_call.arguments) - assert "request_heartbeat" not in _args + # _args = _json.loads(response.messages[0].tool_call.arguments) + # assert "request_heartbeat" not in _args response = client.agents.messages.create_stream( agent_id=agent.id, @@ -419,8 +423,8 @@ def test_approve_and_follow_up_with_error( ) approval_request_id = response.messages[0].id - # Mock the streaming interface to return no tool call on the follow up request heartbeat message - with patch.object(AnthropicStreamingInterface, "get_tool_call_object", return_value=None): + # Mock the streaming adapter to return llm invocation failure on the follow up turn + with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ @@ -437,7 +441,7 @@ def test_approve_and_follow_up_with_error( assert messages is not None stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0] assert stop_reason_message - assert stop_reason_message.stop_reason == "no_tool_call" + assert stop_reason_message.stop_reason == "invalid_llm_response" # Ensure that agent is not bricked response = client.agents.messages.create_stream( @@ -448,11 +452,13 @@ def test_approve_and_follow_up_with_error( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 + assert len(messages) == 4 or len(messages) == 5 assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + if len(messages) == 4: + assert messages[1].message_type == "assistant_message" + else: + assert messages[1].message_type == "tool_call_message" + assert messages[2].message_type == "tool_return_message" # ------------------------------ @@ -469,7 +475,7 @@ def test_deny_tool_call_request( messages=USER_MESSAGE_TEST_APPROVAL, ) approval_request_id = response.messages[0].id - tool_call_id = response.messages[1].tool_call.tool_call_id + tool_call_id = response.messages[2].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -508,16 +514,17 @@ def test_deny_cursor_fetch( approval_request_id = response.messages[0].id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - assert len(messages) == 3 + assert len(messages) == 4 assert messages[0].message_type == "user_message" assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "approval_request_message" - assert messages[2].id == approval_request_id + assert messages[2].message_type == "assistant_message" + assert messages[3].message_type == "approval_request_message" + assert messages[3].id == approval_request_id # Ensure no request_heartbeat on approval request - import json as _json + # import json as _json - _args = _json.loads(messages[2].tool_call.arguments) - assert "request_heartbeat" not in _args + # _args = _json.loads(messages[2].tool_call.arguments) + # assert "request_heartbeat" not in _args last_message_cursor = approval_request_id client.agents.messages.create( @@ -532,13 +539,12 @@ def test_deny_cursor_fetch( ) messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - assert len(messages) == 5 + assert len(messages) == 4 assert messages[0].message_type == "approval_response_message" assert messages[1].message_type == "tool_return_message" assert messages[1].status == "error" - assert messages[2].message_type == "user_message" # heartbeat - assert messages[3].message_type == "reasoning_message" - assert messages[4].message_type == "assistant_message" + assert messages[2].message_type == "reasoning_message" + assert messages[3].message_type == "assistant_message" def test_deny_and_follow_up( @@ -588,8 +594,8 @@ def test_deny_and_follow_up_with_error( ) approval_request_id = response.messages[0].id - # Mock the streaming interface to return no tool call on the follow up request heartbeat message - with patch.object(AnthropicStreamingInterface, "get_tool_call_object", return_value=None): + # Mock the streaming adapter to return llm invocation failure on the follow up turn + with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): response = client.agents.messages.create_stream( agent_id=agent.id, messages=[ @@ -607,7 +613,7 @@ def test_deny_and_follow_up_with_error( assert messages is not None stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0] assert stop_reason_message - assert stop_reason_message.stop_reason == "no_tool_call" + assert stop_reason_message.stop_reason == "invalid_llm_response" # Ensure that agent is not bricked response = client.agents.messages.create_stream(