diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 0e142818..3612eec4 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -320,6 +320,8 @@ class LettaAgentV3(LettaAgentV2): # Raise if no chunks sent yet (response not started, can return error status code) raise else: + yield f"data: {self.stop_reason.model_dump_json()}\n\n" + # Mid-stream error: yield error event to client in SSE format error_chunk = { "error": { diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index b4d0b067..1a343253 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -246,7 +246,7 @@ def test_send_user_message_with_pending_request(client, agent): messages=[MessageCreate(role="user", content="hi")], ) - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) def test_send_approval_message_with_incorrect_request_id(client, agent): @@ -273,7 +273,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): ], ) - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) # ------------------------------ @@ -293,25 +293,22 @@ def test_invoke_approval_request( messages = response.messages assert messages is not None - assert len(messages) == 3 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "approval_request_message" - assert messages[2].tool_call is not None - assert messages[2].tool_call.name == "get_secret_code_tool" - assert messages[2].tool_calls is not None - assert len(messages[2].tool_calls) == 1 - assert messages[2].tool_calls[0]["name"] == "get_secret_code_tool" + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call is not None + assert messages[-1].tool_call.name == "get_secret_code_tool" + assert messages[-1].tool_calls is not None + assert len(messages[-1].tool_calls) == 1 + assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool" # v3/v1 path: approval request tool args must not include request_heartbeat import json as _json - _args = _json.loads(messages[2].tool_call.arguments) + _args = _json.loads(messages[-1].tool_call.arguments) assert "request_heartbeat" not in _args client.agents.context.retrieve(agent_id=agent.id) - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) def test_invoke_approval_request_stream( @@ -327,18 +324,15 @@ def test_invoke_approval_request_stream( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 5 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "approval_request_message" - assert messages[2].tool_call is not None - assert messages[2].tool_call.name == "get_secret_code_tool" - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + assert messages[-3].message_type == "approval_request_message" + assert messages[-3].tool_call is not None + assert messages[-3].tool_call.name == "get_secret_code_tool" + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" client.agents.context.retrieve(agent_id=agent.id) - approve_tool_call(client, agent.id, messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, messages[-3].tool_call.tool_call_id) def test_invoke_tool_after_turning_off_requires_approval( @@ -350,7 +344,7 @@ def test_invoke_tool_after_turning_off_requires_approval( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -424,7 +418,7 @@ def test_approve_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -477,18 +471,16 @@ def test_approve_cursor_fetch( messages=USER_MESSAGE_TEST_APPROVAL, ) last_message_id = response.messages[0].id - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) assert len(messages) == 4 assert messages[0].message_type == "user_message" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "approval_request_message" + assert messages[-1].message_type == "approval_request_message" # Ensure no request_heartbeat on approval request import json as _json - _args = _json.loads(messages[3].tool_call.arguments) + _args = _json.loads(messages[-1].tool_call.arguments) assert "request_heartbeat" not in _args client.agents.messages.create( @@ -509,7 +501,6 @@ def test_approve_cursor_fetch( ) messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - assert len(messages) == 2 or len(messages) == 4 assert messages[0].message_type == "approval_response_message" assert messages[0].approval_request_id == tool_call_id assert messages[0].approve is True @@ -517,9 +508,6 @@ def test_approve_cursor_fetch( assert messages[0].approvals[0]["tool_call_id"] == tool_call_id assert messages[1].message_type == "tool_return_message" assert messages[1].status == "success" - if len(messages) == 4: - assert messages[2].message_type == "reasoning_message" - assert messages[3].message_type == "assistant_message" def test_approve_with_context_check( @@ -530,7 +518,7 @@ def test_approve_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -568,7 +556,7 @@ def test_approve_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, @@ -618,7 +606,7 @@ def test_approve_and_follow_up_with_error( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id # Mock the streaming adapter to return llm invocation failure on the follow up turn with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): @@ -678,7 +666,7 @@ def test_deny_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -702,15 +690,13 @@ def test_deny_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 5 assert messages[0].message_type == "tool_return_message" assert messages[0].tool_call_id == tool_call_id assert messages[0].status == "error" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert SECRET_CODE in messages[2].content - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + if messages[1].message_type == "assistant_message": + assert SECRET_CODE in messages[1].content + elif messages[2].message_type == "assistant_message": + assert SECRET_CODE in messages[2].content def test_deny_cursor_fetch( @@ -723,15 +709,12 @@ def test_deny_cursor_fetch( messages=USER_MESSAGE_TEST_APPROVAL, ) last_message_id = response.messages[0].id - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - assert len(messages) == 4 assert messages[0].message_type == "user_message" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "approval_request_message" - assert messages[3].tool_call.tool_call_id == tool_call_id + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call.tool_call_id == tool_call_id # Ensure no request_heartbeat on approval request # import json as _json @@ -758,15 +741,12 @@ def test_deny_cursor_fetch( ) messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - assert len(messages) == 4 assert messages[0].message_type == "approval_response_message" assert messages[0].approvals[0]["approve"] == False assert messages[0].approvals[0]["tool_call_id"] == tool_call_id assert messages[0].approvals[0]["reason"] == f"You don't need to call the tool, the secret code is {SECRET_CODE}" assert messages[1].message_type == "tool_return_message" assert messages[1].status == "error" - assert messages[2].message_type == "reasoning_message" - assert messages[3].message_type == "assistant_message" def test_deny_with_context_check( @@ -777,7 +757,7 @@ def test_deny_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -817,7 +797,7 @@ def test_deny_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, @@ -847,11 +827,9 @@ def test_deny_and_follow_up( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_deny_and_follow_up_with_error( @@ -862,7 +840,7 @@ def test_deny_and_follow_up_with_error( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id # Mock the streaming adapter to return llm invocation failure on the follow up turn with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): @@ -902,11 +880,9 @@ def test_deny_and_follow_up_with_error( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" # -------------------------------- @@ -922,7 +898,7 @@ def test_client_side_tool_call_request( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -946,16 +922,16 @@ def test_client_side_tool_call_request( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 5 assert messages[0].message_type == "tool_return_message" assert messages[0].tool_call_id == tool_call_id assert messages[0].status == "success" assert messages[0].tool_return == SECRET_CODE - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert SECRET_CODE in messages[2].content - assert messages[3].message_type == "stop_reason" - assert messages[4].message_type == "usage_statistics" + if messages[1].message_type == "assistant_message": + assert SECRET_CODE in messages[1].content + elif messages[2].message_type == "assistant_message": + assert SECRET_CODE in messages[2].content + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_client_side_tool_call_cursor_fetch( @@ -968,15 +944,12 @@ def test_client_side_tool_call_cursor_fetch( messages=USER_MESSAGE_TEST_APPROVAL, ) last_message_id = response.messages[0].id - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) - assert len(messages) == 4 assert messages[0].message_type == "user_message" - assert messages[1].message_type == "reasoning_message" - assert messages[2].message_type == "assistant_message" - assert messages[3].message_type == "approval_request_message" - assert messages[3].tool_call.tool_call_id == tool_call_id + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call.tool_call_id == tool_call_id # Ensure no request_heartbeat on approval request # import json as _json @@ -1003,7 +976,6 @@ def test_client_side_tool_call_cursor_fetch( ) messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - assert len(messages) == 4 assert messages[0].message_type == "approval_response_message" assert messages[0].approvals[0]["type"] == "tool" assert messages[0].approvals[0]["tool_call_id"] == tool_call_id @@ -1013,8 +985,6 @@ def test_client_side_tool_call_cursor_fetch( assert messages[1].status == "success" assert messages[1].tool_call_id == tool_call_id assert messages[1].tool_return == SECRET_CODE - assert messages[2].message_type == "reasoning_message" - assert messages[3].message_type == "assistant_message" def test_client_side_tool_call_with_context_check( @@ -1025,7 +995,7 @@ def test_client_side_tool_call_with_context_check( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id response = client.agents.messages.create_stream( agent_id=agent.id, @@ -1065,7 +1035,7 @@ def test_client_side_tool_call_and_follow_up( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id client.agents.messages.create( agent_id=agent.id, @@ -1095,11 +1065,9 @@ def test_client_side_tool_call_and_follow_up( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_client_side_tool_call_and_follow_up_with_error( @@ -1110,7 +1078,7 @@ def test_client_side_tool_call_and_follow_up_with_error( agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, ) - tool_call_id = response.messages[2].tool_call.tool_call_id + tool_call_id = response.messages[-1].tool_call.tool_call_id # Mock the streaming adapter to return llm invocation failure on the follow up turn with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): @@ -1150,11 +1118,9 @@ def test_client_side_tool_call_and_follow_up_with_error( messages = accumulate_chunks(response) assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "stop_reason" - assert messages[3].message_type == "usage_statistics" + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" def test_parallel_tool_calling( @@ -1170,29 +1136,26 @@ def test_parallel_tool_calling( messages = response.messages assert messages is not None - assert len(messages) == 4 - assert messages[0].message_type == "reasoning_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "tool_call_message" - assert len(messages[2].tool_calls) == 1 - assert messages[2].tool_calls[0]["name"] == "roll_dice_tool" - assert "6" in messages[2].tool_calls[0]["arguments"] - dice_tool_call_id = messages[2].tool_calls[0]["tool_call_id"] + assert messages[-2].message_type == "tool_call_message" + assert len(messages[-2].tool_calls) == 1 + assert messages[-2].tool_calls[0]["name"] == "roll_dice_tool" + assert "6" in messages[-2].tool_calls[0]["arguments"] + dice_tool_call_id = messages[-2].tool_calls[0]["tool_call_id"] - assert messages[3].message_type == "approval_request_message" - assert messages[3].tool_call is not None - assert messages[3].tool_call.name == "get_secret_code_tool" + assert messages[-1].message_type == "approval_request_message" + assert messages[-1].tool_call is not None + assert messages[-1].tool_call.name == "get_secret_code_tool" - assert len(messages[3].tool_calls) == 3 - assert messages[3].tool_calls[0]["name"] == "get_secret_code_tool" - assert "hello world" in messages[3].tool_calls[0]["arguments"] - approve_tool_call_id = messages[3].tool_calls[0]["tool_call_id"] - assert messages[3].tool_calls[1]["name"] == "get_secret_code_tool" - assert "hello letta" in messages[3].tool_calls[1]["arguments"] - deny_tool_call_id = messages[3].tool_calls[1]["tool_call_id"] - assert messages[3].tool_calls[2]["name"] == "get_secret_code_tool" - assert "hello test" in messages[3].tool_calls[2]["arguments"] - client_side_tool_call_id = messages[3].tool_calls[2]["tool_call_id"] + assert len(messages[-1].tool_calls) == 3 + assert messages[-1].tool_calls[0]["name"] == "get_secret_code_tool" + assert "hello world" in messages[-1].tool_calls[0]["arguments"] + approve_tool_call_id = messages[-1].tool_calls[0]["tool_call_id"] + assert messages[-1].tool_calls[1]["name"] == "get_secret_code_tool" + assert "hello letta" in messages[-1].tool_calls[1]["arguments"] + deny_tool_call_id = messages[-1].tool_calls[1]["tool_call_id"] + assert messages[-1].tool_calls[2]["name"] == "get_secret_code_tool" + assert "hello test" in messages[-1].tool_calls[2]["arguments"] + client_side_tool_call_id = messages[-1].tool_calls[2]["tool_call_id"] # ensure context is not bricked client.agents.context.retrieve(agent_id=agent.id) @@ -1300,15 +1263,14 @@ def test_agent_records_last_stop_reason_after_approval_flow( # Verify we got an approval request messages = response.messages assert messages is not None - assert len(messages) == 3 - assert messages[2].message_type == "approval_request_message" + assert messages[-1].message_type == "approval_request_message" # Check agent after approval request (run should be paused with requires_approval) agent_after_request = client.agents.retrieve(agent_id=agent.id) assert agent_after_request.last_stop_reason == "requires_approval" # Approve the tool call - approve_tool_call(client, agent.id, response.messages[2].tool_call.tool_call_id) + approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id) # Check agent after approval (run should complete with end_turn or similar) agent_after_approval = client.agents.retrieve(agent_id=agent.id)