feat: amend hitl tests for new agent loop (#5176)

This commit is contained in:
cthomas
2025-10-06 13:11:33 -07:00
committed by Caren Thomas
parent 4089ad0629
commit a7fa6bb33f

View File

@@ -11,6 +11,7 @@ from dotenv import load_dotenv
from letta_client import AgentState, ApprovalCreate, Letta, MessageCreate, Tool
from letta_client.core.api_error import ApiError
from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.log import get_logger
from letta.schemas.enums import AgentType
@@ -22,7 +23,7 @@ logger = get_logger(__name__)
# ------------------------------
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'. Make sure to set request_heartbeat to True."
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'."
USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
MessageCreate(
role="user",
@@ -137,12 +138,11 @@ def agent(client: Letta, approval_tool_fixture) -> AgentState:
Creates and returns an agent state for testing with a pre-configured agent.
The agent is configured with the requires_approval_tool.
"""
send_message_tool = client.tools.list(name="send_message")[0]
agent_state = client.agents.create(
name="approval_test_agent",
agent_type=AgentType.letta_v1_agent,
include_base_tools=False,
tool_ids=[send_message_tool.id, approval_tool_fixture.id],
tool_ids=[approval_tool_fixture.id],
model="anthropic/claude-sonnet-4-20250514",
embedding="openai/text-embedding-3-small",
tags=["approval_test"],
@@ -214,10 +214,10 @@ def test_send_message_with_requires_approval_tool(
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "approval_request_message"
# v3/v1 path: approval request tool args must not include request_heartbeat
import json as _json
# import json as _json
_args = _json.loads(messages[2].tool_call.arguments)
assert "request_heartbeat" not in _args
# _args = _json.loads(messages[2].tool_call.arguments)
# assert "request_heartbeat" not in _args
assert messages[3].message_type == "stop_reason"
assert messages[4].message_type == "usage_statistics"
@@ -254,14 +254,18 @@ def test_send_message_after_turning_off_requires_approval(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 6 or len(messages) == 8
assert len(messages) == 6 or len(messages) == 8 or len(messages) == 9
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "tool_call_message"
assert messages[3].message_type == "tool_return_message"
if len(messages) > 6:
if len(messages) == 8:
assert messages[4].message_type == "reasoning_message"
assert messages[5].message_type == "assistant_message"
elif len(messages) == 9:
assert messages[4].message_type == "reasoning_message"
assert messages[5].message_type == "tool_call_message"
assert messages[6].message_type == "tool_return_message"
# ------------------------------
@@ -280,10 +284,10 @@ def test_approve_tool_call_request(
approval_request_id = response.messages[0].id
tool_call_id = response.messages[2].tool_call.tool_call_id
# Ensure no request_heartbeat on approval request
import json as _json
# import json as _json
_args = _json.loads(response.messages[0].tool_call.arguments)
assert "request_heartbeat" not in _args
# _args = _json.loads(response.messages[0].tool_call.arguments)
# assert "request_heartbeat" not in _args
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -419,8 +423,8 @@ def test_approve_and_follow_up_with_error(
)
approval_request_id = response.messages[0].id
# Mock the streaming interface to return no tool call on the follow up request heartbeat message
with patch.object(AnthropicStreamingInterface, "get_tool_call_object", return_value=None):
# Mock the streaming adapter so that invoke_llm raises on the follow-up turn, simulating an LLM invocation failure
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
@@ -437,7 +441,7 @@ def test_approve_and_follow_up_with_error(
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "no_tool_call"
assert stop_reason_message.stop_reason == "invalid_llm_response"
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(
@@ -448,11 +452,13 @@ def test_approve_and_follow_up_with_error(
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert len(messages) == 4 or len(messages) == 5
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "stop_reason"
assert messages[3].message_type == "usage_statistics"
if len(messages) == 4:
assert messages[1].message_type == "assistant_message"
else:
assert messages[1].message_type == "tool_call_message"
assert messages[2].message_type == "tool_return_message"
# ------------------------------
@@ -469,7 +475,7 @@ def test_deny_tool_call_request(
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
tool_call_id = response.messages[1].tool_call.tool_call_id
tool_call_id = response.messages[2].tool_call.tool_call_id
response = client.agents.messages.create_stream(
agent_id=agent.id,
@@ -508,16 +514,17 @@ def test_deny_cursor_fetch(
approval_request_id = response.messages[0].id
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
assert len(messages) == 3
assert len(messages) == 4
assert messages[0].message_type == "user_message"
assert messages[1].message_type == "reasoning_message"
assert messages[2].message_type == "approval_request_message"
assert messages[2].id == approval_request_id
assert messages[2].message_type == "assistant_message"
assert messages[3].message_type == "approval_request_message"
assert messages[3].id == approval_request_id
# Ensure no request_heartbeat on approval request
import json as _json
# import json as _json
_args = _json.loads(messages[2].tool_call.arguments)
assert "request_heartbeat" not in _args
# _args = _json.loads(messages[3].tool_call.arguments)  # approval request is now at index 3
# assert "request_heartbeat" not in _args
last_message_cursor = approval_request_id
client.agents.messages.create(
@@ -532,13 +539,12 @@ def test_deny_cursor_fetch(
)
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
assert len(messages) == 5
assert len(messages) == 4
assert messages[0].message_type == "approval_response_message"
assert messages[1].message_type == "tool_return_message"
assert messages[1].status == "error"
assert messages[2].message_type == "user_message" # heartbeat
assert messages[3].message_type == "reasoning_message"
assert messages[4].message_type == "assistant_message"
assert messages[2].message_type == "reasoning_message"
assert messages[3].message_type == "assistant_message"
def test_deny_and_follow_up(
@@ -588,8 +594,8 @@ def test_deny_and_follow_up_with_error(
)
approval_request_id = response.messages[0].id
# Mock the streaming interface to return no tool call on the follow up request heartbeat message
with patch.object(AnthropicStreamingInterface, "get_tool_call_object", return_value=None):
# Mock the streaming adapter so that invoke_llm raises on the follow-up turn, simulating an LLM invocation failure
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
@@ -607,7 +613,7 @@ def test_deny_and_follow_up_with_error(
assert messages is not None
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
assert stop_reason_message
assert stop_reason_message.stop_reason == "no_tool_call"
assert stop_reason_message.stop_reason == "invalid_llm_response"
# Ensure that agent is not bricked
response = client.agents.messages.create_stream(