feat: amend hitl tests for new agent loop (#5176)
This commit is contained in:
@@ -11,6 +11,7 @@ from dotenv import load_dotenv
|
||||
from letta_client import AgentState, ApprovalCreate, Letta, MessageCreate, Tool
|
||||
from letta_client.core.api_error import ApiError
|
||||
|
||||
from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
|
||||
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import AgentType
|
||||
@@ -22,7 +23,7 @@ logger = get_logger(__name__)
|
||||
# ------------------------------
|
||||
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'. Make sure to set request_heartbeat to True."
|
||||
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'."
|
||||
USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
@@ -137,12 +138,11 @@ def agent(client: Letta, approval_tool_fixture) -> AgentState:
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is configured with the requires_approval_tool.
|
||||
"""
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
agent_state = client.agents.create(
|
||||
name="approval_test_agent",
|
||||
agent_type=AgentType.letta_v1_agent,
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id, approval_tool_fixture.id],
|
||||
tool_ids=[approval_tool_fixture.id],
|
||||
model="anthropic/claude-sonnet-4-20250514",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tags=["approval_test"],
|
||||
@@ -214,10 +214,10 @@ def test_send_message_with_requires_approval_tool(
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "approval_request_message"
|
||||
# v3/v1 path: approval request tool args must not include request_heartbeat
|
||||
import json as _json
|
||||
# import json as _json
|
||||
|
||||
_args = _json.loads(messages[2].tool_call.arguments)
|
||||
assert "request_heartbeat" not in _args
|
||||
# _args = _json.loads(messages[2].tool_call.arguments)
|
||||
# assert "request_heartbeat" not in _args
|
||||
assert messages[3].message_type == "stop_reason"
|
||||
assert messages[4].message_type == "usage_statistics"
|
||||
|
||||
@@ -254,14 +254,18 @@ def test_send_message_after_turning_off_requires_approval(
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) == 6 or len(messages) == 8
|
||||
assert len(messages) == 6 or len(messages) == 8 or len(messages) == 9
|
||||
assert messages[0].message_type == "reasoning_message"
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "tool_call_message"
|
||||
assert messages[3].message_type == "tool_return_message"
|
||||
if len(messages) > 6:
|
||||
if len(messages) == 8:
|
||||
assert messages[4].message_type == "reasoning_message"
|
||||
assert messages[5].message_type == "assistant_message"
|
||||
elif len(messages) == 9:
|
||||
assert messages[4].message_type == "reasoning_message"
|
||||
assert messages[5].message_type == "tool_call_message"
|
||||
assert messages[6].message_type == "tool_return_message"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@@ -280,10 +284,10 @@ def test_approve_tool_call_request(
|
||||
approval_request_id = response.messages[0].id
|
||||
tool_call_id = response.messages[2].tool_call.tool_call_id
|
||||
# Ensure no request_heartbeat on approval request
|
||||
import json as _json
|
||||
# import json as _json
|
||||
|
||||
_args = _json.loads(response.messages[0].tool_call.arguments)
|
||||
assert "request_heartbeat" not in _args
|
||||
# _args = _json.loads(response.messages[0].tool_call.arguments)
|
||||
# assert "request_heartbeat" not in _args
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
@@ -419,8 +423,8 @@ def test_approve_and_follow_up_with_error(
|
||||
)
|
||||
approval_request_id = response.messages[0].id
|
||||
|
||||
# Mock the streaming interface to return no tool call on the follow up request heartbeat message
|
||||
with patch.object(AnthropicStreamingInterface, "get_tool_call_object", return_value=None):
|
||||
# Mock the streaming adapter so that the LLM invocation raises an error on the follow-up turn
|
||||
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
@@ -437,7 +441,7 @@ def test_approve_and_follow_up_with_error(
|
||||
assert messages is not None
|
||||
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
|
||||
assert stop_reason_message
|
||||
assert stop_reason_message.stop_reason == "no_tool_call"
|
||||
assert stop_reason_message.stop_reason == "invalid_llm_response"
|
||||
|
||||
# Ensure that agent is not bricked
|
||||
response = client.agents.messages.create_stream(
|
||||
@@ -448,11 +452,13 @@ def test_approve_and_follow_up_with_error(
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) == 4
|
||||
assert len(messages) == 4 or len(messages) == 5
|
||||
assert messages[0].message_type == "reasoning_message"
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "stop_reason"
|
||||
assert messages[3].message_type == "usage_statistics"
|
||||
if len(messages) == 4:
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
else:
|
||||
assert messages[1].message_type == "tool_call_message"
|
||||
assert messages[2].message_type == "tool_return_message"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@@ -469,7 +475,7 @@ def test_deny_tool_call_request(
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
approval_request_id = response.messages[0].id
|
||||
tool_call_id = response.messages[1].tool_call.tool_call_id
|
||||
tool_call_id = response.messages[2].tool_call.tool_call_id
|
||||
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
@@ -508,16 +514,17 @@ def test_deny_cursor_fetch(
|
||||
approval_request_id = response.messages[0].id
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
|
||||
assert len(messages) == 3
|
||||
assert len(messages) == 4
|
||||
assert messages[0].message_type == "user_message"
|
||||
assert messages[1].message_type == "reasoning_message"
|
||||
assert messages[2].message_type == "approval_request_message"
|
||||
assert messages[2].id == approval_request_id
|
||||
assert messages[2].message_type == "assistant_message"
|
||||
assert messages[3].message_type == "approval_request_message"
|
||||
assert messages[3].id == approval_request_id
|
||||
# Ensure no request_heartbeat on approval request
|
||||
import json as _json
|
||||
# import json as _json
|
||||
|
||||
_args = _json.loads(messages[2].tool_call.arguments)
|
||||
assert "request_heartbeat" not in _args
|
||||
# _args = _json.loads(messages[2].tool_call.arguments)
|
||||
# assert "request_heartbeat" not in _args
|
||||
|
||||
last_message_cursor = approval_request_id
|
||||
client.agents.messages.create(
|
||||
@@ -532,13 +539,12 @@ def test_deny_cursor_fetch(
|
||||
)
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
|
||||
assert len(messages) == 5
|
||||
assert len(messages) == 4
|
||||
assert messages[0].message_type == "approval_response_message"
|
||||
assert messages[1].message_type == "tool_return_message"
|
||||
assert messages[1].status == "error"
|
||||
assert messages[2].message_type == "user_message" # heartbeat
|
||||
assert messages[3].message_type == "reasoning_message"
|
||||
assert messages[4].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "reasoning_message"
|
||||
assert messages[3].message_type == "assistant_message"
|
||||
|
||||
|
||||
def test_deny_and_follow_up(
|
||||
@@ -588,8 +594,8 @@ def test_deny_and_follow_up_with_error(
|
||||
)
|
||||
approval_request_id = response.messages[0].id
|
||||
|
||||
# Mock the streaming interface to return no tool call on the follow up request heartbeat message
|
||||
with patch.object(AnthropicStreamingInterface, "get_tool_call_object", return_value=None):
|
||||
# Mock the streaming adapter so that the LLM invocation raises an error on the follow-up turn
|
||||
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
@@ -607,7 +613,7 @@ def test_deny_and_follow_up_with_error(
|
||||
assert messages is not None
|
||||
stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0]
|
||||
assert stop_reason_message
|
||||
assert stop_reason_message.stop_reason == "no_tool_call"
|
||||
assert stop_reason_message.stop_reason == "invalid_llm_response"
|
||||
|
||||
# Ensure that agent is not bricked
|
||||
response = client.agents.messages.create_stream(
|
||||
|
||||
Reference in New Issue
Block a user