From f3df0433aeb1902bc0bc05f6e9ad34ee463bda6b Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 1 Sep 2025 16:23:06 -0700 Subject: [PATCH] feat: add more tests for hitl (#4339) --- tests/integration_test_human_in_the_loop.py | 155 ++++++++++++++++++-- 1 file changed, 141 insertions(+), 14 deletions(-) diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 8e132a4b..5d7c495f 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -2,16 +2,15 @@ import os import threading import time import uuid -from typing import Any, List +from typing import List import pytest import requests from dotenv import load_dotenv -from letta_client import ApprovalCreate, Letta, MessageCreate +from letta_client import AgentState, ApprovalCreate, Letta, MessageCreate, Tool from letta_client.core.api_error import ApiError from letta.log import get_logger -from letta.schemas.agent import AgentState logger = get_logger(__name__) @@ -30,6 +29,15 @@ USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [ ] FAKE_REQUEST_ID = str(uuid.uuid4()) SECRET_CODE = str(740845635798344975) +USER_MESSAGE_FOLLOW_UP_OTID = str(uuid.uuid4()) +USER_MESSAGE_FOLLOW_UP_CONTENT = "Thank you for the secret code." +USER_MESSAGE_FOLLOW_UP: List[MessageCreate] = [ + MessageCreate( + role="user", + content=USER_MESSAGE_FOLLOW_UP_CONTENT, + otid=USER_MESSAGE_FOLLOW_UP_OTID, + ) +] def get_secret_code_tool(input_text: str) -> str: @@ -95,17 +103,19 @@ def client(server_url: str) -> Letta: @pytest.fixture(scope="function") -def approval_tool_fixture(client: Letta): +def approval_tool_fixture(client: Letta) -> Tool: """ Creates and returns a tool that requires approval for testing. """ client.tools.upsert_base_tools() approval_tool = client.tools.upsert_from_function( func=get_secret_code_tool, - # default_requires_approval=True, switch to this once it is supported in sdk + default_requires_approval=True, ) yield approval_tool + client.tools.delete(tool_id=approval_tool.id) + @pytest.fixture(scope="function") def agent(client: Letta, approval_tool_fixture) -> AgentState: @@ -122,16 +132,13 @@ def agent(client: Letta, approval_tool_fixture) -> AgentState: embedding="openai/text-embedding-3-small", tags=["approval_test"], ) - client.agents.tools.modify_approval( - agent_id=agent_state.id, - tool_name=approval_tool_fixture.name, - requires_approval=True, - ) yield agent_state + client.agents.delete(agent_id=agent_state.id) + # ------------------------------ -# Test Cases +# Error Test Cases # ------------------------------ @@ -169,6 +176,11 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): ) +# ------------------------------ +# Request Test Cases +# ------------------------------ + + def test_send_message_with_requires_approval_tool( client: Letta, agent: AgentState, @@ -184,6 +196,52 @@ def test_send_message_with_requires_approval_tool( assert response.messages[1].message_type == "approval_request_message" +def test_send_message_after_turning_off_requires_approval( + client: Letta, + agent: AgentState, + approval_tool_fixture: Tool, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + approval_request_id = response.messages[0].id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, + approval_request_id=approval_request_id, + ), + ], + ) + + client.agents.tools.modify_approval( + agent_id=agent.id, + tool_name=approval_tool_fixture.name, + requires_approval=False, + ) + + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + + assert response.messages is not None + assert len(response.messages) == 5 + assert response.messages[0].message_type == "reasoning_message" + assert response.messages[1].message_type == "tool_call_message" + assert response.messages[2].message_type == "tool_return_message" + assert response.messages[3].message_type == "reasoning_message" + assert response.messages[4].message_type == "assistant_message" + + +# ------------------------------ +# Approve Test Cases +# ------------------------------ + + def test_approve_tool_call_request( client: Letta, agent: AgentState, @@ -206,12 +264,13 @@ def test_approve_tool_call_request( ) assert response.messages is not None - assert len(response.messages) == 3 + assert len(response.messages) == 1 or len(response.messages) == 3 assert response.messages[0].message_type == "tool_return_message" assert response.messages[0].tool_call_id == tool_call_id assert response.messages[0].status == "success" - assert response.messages[1].message_type == "reasoning_message" - assert response.messages[2].message_type == "assistant_message" + if len(response.messages) == 3: + assert response.messages[1].message_type == "reasoning_message" + assert response.messages[2].message_type == "assistant_message" def test_approve_cursor_fetch( @@ -253,6 +312,42 @@ def test_approve_cursor_fetch( assert messages[4].message_type == "assistant_message" +def test_approve_and_follow_up( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + approval_request_id = response.messages[0].id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, + approval_request_id=approval_request_id, + ), + ], + ) + + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + ) + + assert response.messages is not None + assert len(response.messages) == 2 + assert response.messages[0].message_type == "reasoning_message" + assert response.messages[1].message_type == "assistant_message" + + +# ------------------------------ +# Deny Test Cases +# ------------------------------ + + def test_deny_tool_call_request( client: Letta, agent: AgentState, @@ -323,3 +418,35 @@ def test_deny_cursor_fetch( assert messages[2].message_type == "user_message" # heartbeat assert messages[3].message_type == "reasoning_message" assert messages[4].message_type == "assistant_message" + + +def test_deny_and_follow_up( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + approval_request_id = response.messages[0].id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=False, + approval_request_id=approval_request_id, + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", + ), + ], + ) + + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + ) + + assert response.messages is not None + assert len(response.messages) == 2 + assert response.messages[0].message_type == "reasoning_message" + assert response.messages[1].message_type == "assistant_message"