feat: add more tests for hitl (#4339)

This commit is contained in:
cthomas
2025-09-01 16:23:06 -07:00
committed by GitHub
parent 71a5eaa262
commit f3df0433ae

View File

@@ -2,16 +2,15 @@ import os
import threading
import time
import uuid
from typing import Any, List
from typing import List
import pytest
import requests
from dotenv import load_dotenv
from letta_client import ApprovalCreate, Letta, MessageCreate
from letta_client import AgentState, ApprovalCreate, Letta, MessageCreate, Tool
from letta_client.core.api_error import ApiError
from letta.log import get_logger
from letta.schemas.agent import AgentState
logger = get_logger(__name__)
@@ -30,6 +29,15 @@ USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
]
FAKE_REQUEST_ID = str(uuid.uuid4())
SECRET_CODE = str(740845635798344975)
USER_MESSAGE_FOLLOW_UP_OTID = str(uuid.uuid4())
USER_MESSAGE_FOLLOW_UP_CONTENT = "Thank you for the secret code."
USER_MESSAGE_FOLLOW_UP: List[MessageCreate] = [
MessageCreate(
role="user",
content=USER_MESSAGE_FOLLOW_UP_CONTENT,
otid=USER_MESSAGE_FOLLOW_UP_OTID,
)
]
def get_secret_code_tool(input_text: str) -> str:
@@ -95,17 +103,19 @@ def client(server_url: str) -> Letta:
@pytest.fixture(scope="function")
def approval_tool_fixture(client: Letta):
def approval_tool_fixture(client: Letta) -> Tool:
"""
Creates and returns a tool that requires approval for testing.
"""
client.tools.upsert_base_tools()
approval_tool = client.tools.upsert_from_function(
func=get_secret_code_tool,
# default_requires_approval=True, switch to this once it is supported in sdk
default_requires_approval=True,
)
yield approval_tool
client.tools.delete(tool_id=approval_tool.id)
@pytest.fixture(scope="function")
def agent(client: Letta, approval_tool_fixture) -> AgentState:
@@ -122,16 +132,13 @@ def agent(client: Letta, approval_tool_fixture) -> AgentState:
embedding="openai/text-embedding-3-small",
tags=["approval_test"],
)
client.agents.tools.modify_approval(
agent_id=agent_state.id,
tool_name=approval_tool_fixture.name,
requires_approval=True,
)
yield agent_state
client.agents.delete(agent_id=agent_state.id)
# ------------------------------
# Test Cases
# Error Test Cases
# ------------------------------
@@ -169,6 +176,11 @@ def test_send_approval_message_with_incorrect_request_id(client, agent):
)
# ------------------------------
# Request Test Cases
# ------------------------------
def test_send_message_with_requires_approval_tool(
client: Letta,
agent: AgentState,
@@ -184,6 +196,52 @@ def test_send_message_with_requires_approval_tool(
assert response.messages[1].message_type == "approval_request_message"
def test_send_message_after_turning_off_requires_approval(
client: Letta,
agent: AgentState,
approval_tool_fixture: Tool,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
approve=True,
approval_request_id=approval_request_id,
),
],
)
client.agents.tools.modify_approval(
agent_id=agent.id,
tool_name=approval_tool_fixture.name,
requires_approval=False,
)
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
assert response.messages is not None
assert len(response.messages) == 5
assert response.messages[0].message_type == "reasoning_message"
assert response.messages[1].message_type == "tool_call_message"
assert response.messages[2].message_type == "tool_return_message"
assert response.messages[3].message_type == "reasoning_message"
assert response.messages[4].message_type == "assistant_message"
# ------------------------------
# Approve Test Cases
# ------------------------------
def test_approve_tool_call_request(
client: Letta,
agent: AgentState,
@@ -206,12 +264,13 @@ def test_approve_tool_call_request(
)
assert response.messages is not None
assert len(response.messages) == 3
assert len(response.messages) == 1 or len(response.messages) == 3
assert response.messages[0].message_type == "tool_return_message"
assert response.messages[0].tool_call_id == tool_call_id
assert response.messages[0].status == "success"
assert response.messages[1].message_type == "reasoning_message"
assert response.messages[2].message_type == "assistant_message"
if len(response.messages) == 3:
assert response.messages[1].message_type == "reasoning_message"
assert response.messages[2].message_type == "assistant_message"
def test_approve_cursor_fetch(
@@ -253,6 +312,42 @@ def test_approve_cursor_fetch(
assert messages[4].message_type == "assistant_message"
def test_approve_and_follow_up(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
approve=True,
approval_request_id=approval_request_id,
),
],
)
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
assert response.messages is not None
assert len(response.messages) == 2
assert response.messages[0].message_type == "reasoning_message"
assert response.messages[1].message_type == "assistant_message"
# ------------------------------
# Deny Test Cases
# ------------------------------
def test_deny_tool_call_request(
client: Letta,
agent: AgentState,
@@ -323,3 +418,35 @@ def test_deny_cursor_fetch(
assert messages[2].message_type == "user_message" # heartbeat
assert messages[3].message_type == "reasoning_message"
assert messages[4].message_type == "assistant_message"
def test_deny_and_follow_up(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
client.agents.messages.create(
agent_id=agent.id,
messages=[
ApprovalCreate(
approve=False,
approval_request_id=approval_request_id,
reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",
),
],
)
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
)
assert response.messages is not None
assert len(response.messages) == 2
assert response.messages[0].message_type == "reasoning_message"
assert response.messages[1].message_type == "assistant_message"