Files
letta-server/tests/integration_test_human_in_the_loop.py
cthomas f3112f75a3 feat: add human in the loop tests to CI (#4335)
feat: add human in the loop tests to CI
2025-09-01 11:10:31 -07:00

247 lines
7.5 KiB
Python

import os
import threading
import time
import uuid
from typing import Any, List
import pytest
import requests
from dotenv import load_dotenv
from letta_client import ApprovalCreate, Letta, MessageCreate
from letta_client.core.api_error import ApiError
from letta.log import get_logger
from letta.schemas.agent import AgentState
logger = get_logger(__name__)

# ------------------------------
# Helper Functions and Constants
# ------------------------------

# otid attached to the scripted user message; generated once when the module is imported.
USER_MESSAGE_OTID = str(uuid.uuid4())
# Prompt intended to make the model call the approval-gated tool.
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'."
USER_MESSAGE_TEST_APPROVAL: List[MessageCreate] = [
    MessageCreate(role="user", content=USER_MESSAGE_CONTENT, otid=USER_MESSAGE_OTID),
]

# Random, well-formed id that is not tied to any approval request on the server.
FAKE_REQUEST_ID = str(uuid.uuid4())
# Code handed to the agent in the denial reason; the deny test expects it echoed back.
SECRET_CODE = "740845635798344975"
def get_secret_code_tool(input_text: str) -> str:
    """
    A tool that returns the secret code based on the input. This tool requires approval before execution.

    Args:
        input_text (str): The input text to process.

    Returns:
        str: The secret code based on the input text.
    """
    # NOTE: this docstring is uploaded as the tool description, so keep it stable.
    # Built-in str hashing is salted per interpreter (PYTHONHASHSEED); the value is
    # only stable within one process, which is fine because no test asserts on it.
    raw_hash = hash(input_text)
    return str(abs(raw_hash))
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.

    If LETTA_SERVER_URL is set, that URL is returned and no server is started.
    Otherwise the server is started in a daemon background thread and its
    /v1/health endpoint is polled until it answers with a non-5xx status.

    Returns:
        str: Base URL of the Letta server.

    Raises:
        RuntimeError: If the locally started server is not reachable within the timeout.
    """

    def _run_server() -> None:
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout).
        timeout_seconds = 30
        # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                # requests has no default timeout: without one, a server that accepts the
                # connection but never replies would block this loop past the deadline.
                remaining = max(deadline - time.monotonic(), 0.1)
                resp = requests.get(url + "/v1/health", timeout=remaining)
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Not up yet (connection refused, timeout, ...): keep polling.
                pass
            time.sleep(0.1)
        else:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    return url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """
    Build the synchronous Letta REST client shared by every test in this module.

    Args:
        server_url: Base URL supplied by the ``server_url`` fixture.

    Yields:
        Letta: A client bound to ``server_url``. No teardown is performed.
    """
    yield Letta(base_url=server_url)
@pytest.fixture(scope="function")
def approval_tool_fixture(client: Letta):
    """
    Register ``get_secret_code_tool`` on the server and yield the resulting tool.

    The base tools are upserted first so that ``send_message`` can be looked up
    by name afterwards. Approval is not switched on here; it is enabled per agent.

    Args:
        client: Letta REST client.

    Yields:
        The tool object returned by ``client.tools.upsert_from_function``.
    """
    client.tools.upsert_base_tools()
    # TODO: pass default_requires_approval=True here once the SDK supports it.
    yield client.tools.upsert_from_function(func=get_secret_code_tool)
@pytest.fixture(scope="function")
def agent(client: Letta, approval_tool_fixture) -> AgentState:
    """
    Creates an agent for one test and deletes it again afterwards.

    The agent gets only ``send_message`` and the approval tool (base tools are
    excluded), and approval is switched on for the approval tool via
    ``client.agents.tools.modify_approval``.

    Args:
        client: Letta REST client.
        approval_tool_fixture: The tool that should require approval.

    Yields:
        AgentState: The newly created agent.
    """
    send_message_tool = client.tools.list(name="send_message")[0]
    agent_state = client.agents.create(
        name="approval_test_agent",
        include_base_tools=False,
        tool_ids=[send_message_tool.id, approval_tool_fixture.id],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-3-small",
        tags=["approval_test"],
    )
    try:
        client.agents.tools.modify_approval(
            agent_id=agent_state.id,
            tool_name=approval_tool_fixture.name,
            requires_approval=True,
        )
        yield agent_state
    finally:
        # Function-scoped fixture: without this, every test (and a failed
        # modify_approval) leaves an agent behind on the server.
        try:
            client.agents.delete(agent_id=agent_state.id)
        except ApiError as exc:
            # Best-effort cleanup; don't turn a passing test into a teardown error.
            logger.warning("Failed to delete test agent %s: %s", agent_state.id, exc)
# ------------------------------
# Test Cases
# ------------------------------
def test_send_approval_without_pending_request(client, agent):
    """Approving when no tool call is pending must be rejected by the server."""
    bogus_approval = ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)
    with pytest.raises(ApiError, match="No tool call is currently awaiting approval"):
        client.agents.messages.create(agent_id=agent.id, messages=[bogus_approval])
def test_send_user_message_with_pending_request(client, agent):
    """A plain user message must be rejected while a tool call awaits approval."""
    setup = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # Check the precondition explicitly: if the model answered without calling the
    # tool, nothing is pending and the check below would fail with "DID NOT RAISE".
    assert setup.messages, "expected a response to the setup message"
    assert any(
        m.message_type == "approval_request_message" for m in setup.messages
    ), "setup message did not leave a tool call awaiting approval"

    with pytest.raises(ApiError, match="Please approve or deny the pending request before continuing"):
        client.agents.messages.create(
            agent_id=agent.id,
            messages=[MessageCreate(role="user", content="hi")],
        )
def test_send_approval_message_with_incorrect_request_id(client, agent):
    """An approval whose request id does not match the pending request must be rejected."""
    setup = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # Without a pending request the server would report a different error, so make
    # the precondition explicit instead of failing on a confusing message mismatch.
    assert setup.messages, "expected a response to the setup message"
    assert any(
        m.message_type == "approval_request_message" for m in setup.messages
    ), "setup message did not leave a tool call awaiting approval"

    with pytest.raises(ApiError, match="Invalid approval request ID"):
        client.agents.messages.create(
            agent_id=agent.id,
            messages=[ApprovalCreate(approve=True, approval_request_id=FAKE_REQUEST_ID)],
        )
def test_send_message_with_requires_approval_tool(client: Letta, agent: AgentState) -> None:
    """Calling an approval-gated tool stops the turn at an approval request."""
    response = client.agents.messages.create(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL)

    assert response.messages is not None
    observed_types = [message.message_type for message in response.messages]
    assert observed_types == ["reasoning_message", "approval_request_message"]
def test_approve_tool_call_request(
    client: Letta,
    agent: AgentState,
) -> None:
    """Approving a pending tool call runs the tool and lets the agent finish its turn."""
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # Read both ids from the approval request itself rather than from positional
    # neighbours (messages[0] is the reasoning message, not the approval request).
    approval_requests = [m for m in response.messages if m.message_type == "approval_request_message"]
    assert len(approval_requests) == 1, "expected exactly one pending approval request"
    approval_request = approval_requests[0]
    approval_request_id = approval_request.id
    tool_call_id = approval_request.tool_call.tool_call_id

    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,
                approval_request_id=approval_request_id,
            ),
        ],
    )

    assert response.messages is not None
    assert len(response.messages) == 3
    assert response.messages[0].message_type == "tool_return_message"
    assert response.messages[0].tool_call_id == tool_call_id
    assert response.messages[0].status == "success"
    assert response.messages[1].message_type == "reasoning_message"
    assert response.messages[2].message_type == "assistant_message"
def test_deny_tool_call_request(
    client: Letta,
    agent: AgentState,
) -> None:
    """
    Denying a pending tool call yields an error tool return, and the agent uses the
    denial reason (which contains SECRET_CODE) in its reply.
    """
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # Read both ids from the approval request itself rather than from positional
    # neighbours (messages[0] is the reasoning message, not the approval request).
    approval_requests = [m for m in response.messages if m.message_type == "approval_request_message"]
    assert len(approval_requests) == 1, "expected exactly one pending approval request"
    approval_request = approval_requests[0]
    approval_request_id = approval_request.id
    tool_call_id = approval_request.tool_call.tool_call_id

    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,
                approval_request_id=approval_request_id,
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",
            ),
        ],
    )

    assert response.messages is not None
    assert len(response.messages) == 3
    assert response.messages[0].message_type == "tool_return_message"
    assert response.messages[0].tool_call_id == tool_call_id
    assert response.messages[0].status == "error"
    assert response.messages[1].message_type == "reasoning_message"
    assert response.messages[2].message_type == "assistant_message"
    assert SECRET_CODE in response.messages[2].content