fix: use shared event + .athrow() to properly set stream_was_cancelled flag
**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.
**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation
**Tests:**
- test_tool_call with cancellation (OpenAI models) ✅
- test_approve_with_cancellation (approval flow + concurrent cancel) ✅
**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race
👾 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
1403 lines
45 KiB
Python
1403 lines
45 KiB
Python
import asyncio
|
|
import logging
|
|
import uuid
|
|
from typing import Any, List
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from letta_client import APIError, Letta
|
|
from letta_client.types import AgentState, MessageCreateParam, Tool
|
|
from letta_client.types.agents import ApprovalCreateParam
|
|
|
|
from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ------------------------------
|
|
# Helper Functions and Constants
|
|
# ------------------------------
|
|
|
|
USER_MESSAGE_OTID = str(uuid.uuid4())
|
|
USER_MESSAGE_CONTENT = "This is an automated test message. Call the get_secret_code_tool to get the code for text 'hello world'."
|
|
USER_MESSAGE_TEST_APPROVAL: List[MessageCreateParam] = [
|
|
MessageCreateParam(
|
|
role="user",
|
|
content=USER_MESSAGE_CONTENT,
|
|
otid=USER_MESSAGE_OTID,
|
|
)
|
|
]
|
|
FAKE_REQUEST_ID = str(uuid.uuid4())
|
|
SECRET_CODE = str(740845635798344975)
|
|
USER_MESSAGE_FOLLOW_UP_OTID = str(uuid.uuid4())
|
|
USER_MESSAGE_FOLLOW_UP_CONTENT = "Thank you for the secret code."
|
|
USER_MESSAGE_FOLLOW_UP: List[MessageCreateParam] = [
|
|
MessageCreateParam(
|
|
role="user",
|
|
content=USER_MESSAGE_FOLLOW_UP_CONTENT,
|
|
otid=USER_MESSAGE_FOLLOW_UP_OTID,
|
|
)
|
|
]
|
|
USER_MESSAGE_PARALLEL_TOOL_CALL_CONTENT = "This is an automated test message. Call the get_secret_code_tool 3 times in parallel for the following inputs: 'hello world', 'hello letta', 'hello test', and also call the roll_dice_tool once with a 16-sided dice."
|
|
USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [
|
|
MessageCreateParam(
|
|
role="user",
|
|
content=USER_MESSAGE_PARALLEL_TOOL_CALL_CONTENT,
|
|
otid=USER_MESSAGE_OTID,
|
|
)
|
|
]
|
|
|
|
|
|
def get_secret_code_tool(input_text: str) -> str:
|
|
"""
|
|
A tool that returns the secret code based on the input. This tool requires approval before execution.
|
|
Args:
|
|
input_text (str): The input text to process.
|
|
Returns:
|
|
str: The secret code based on the input text.
|
|
"""
|
|
return str(abs(hash(input_text)))
|
|
|
|
|
|
def roll_dice_tool(num_sides: int) -> str:
|
|
"""
|
|
A tool that returns a random number between 1 and num_sides.
|
|
Args:
|
|
num_sides (int): The number of sides on the die.
|
|
Returns:
|
|
str: The random number between 1 and num_sides.
|
|
"""
|
|
import random
|
|
|
|
return str(random.randint(1, num_sides))
|
|
|
|
|
|
def accumulate_chunks(stream):
|
|
messages = []
|
|
current_message = None
|
|
prev_message_type = None
|
|
|
|
for chunk in stream:
|
|
# Handle chunks that might not have message_type (like pings)
|
|
if not hasattr(chunk, "message_type"):
|
|
continue
|
|
|
|
current_message_type = getattr(chunk, "message_type", None)
|
|
|
|
if prev_message_type != current_message_type:
|
|
# Save the previous message if it exists
|
|
if current_message is not None:
|
|
messages.append(current_message)
|
|
# Start a new message
|
|
current_message = chunk
|
|
else:
|
|
# Accumulate content for same message type (token streaming)
|
|
if current_message is not None and hasattr(current_message, "content") and hasattr(chunk, "content"):
|
|
current_message.content += chunk.content
|
|
|
|
prev_message_type = current_message_type
|
|
|
|
# Don't forget the last message
|
|
if current_message is not None:
|
|
messages.append(current_message)
|
|
|
|
return [m for m in messages if m is not None]
|
|
|
|
|
|
def approve_tool_call(client: Letta, agent_id: str, tool_call_id: str):
|
|
client.agents.messages.create(
|
|
agent_id=agent_id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
|
|
# ------------------------------
|
|
# Fixtures
|
|
# ------------------------------
|
|
# Note: server_url and client fixtures are inherited from tests/conftest.py
|
|
|
|
|
|
@pytest.fixture(scope="function")
|
|
def approval_tool_fixture(client: Letta):
|
|
"""
|
|
Creates and returns a tool that requires approval for testing.
|
|
"""
|
|
approval_tool = client.tools.upsert_from_function(
|
|
func=get_secret_code_tool,
|
|
default_requires_approval=True,
|
|
)
|
|
yield approval_tool
|
|
|
|
client.tools.delete(tool_id=approval_tool.id)
|
|
|
|
|
|
@pytest.fixture(scope="function")
|
|
def dice_tool_fixture(client: Letta):
|
|
dice_tool = client.tools.upsert_from_function(
|
|
func=roll_dice_tool,
|
|
)
|
|
yield dice_tool
|
|
|
|
client.tools.delete(tool_id=dice_tool.id)
|
|
|
|
|
|
@pytest.fixture(scope="function")
|
|
def agent(client: Letta, approval_tool_fixture, dice_tool_fixture) -> AgentState:
|
|
"""
|
|
Creates and returns an agent state for testing with a pre-configured agent.
|
|
The agent is configured with the requires_approval_tool.
|
|
"""
|
|
agent_state = client.agents.create(
|
|
name="approval_test_agent",
|
|
agent_type="letta_v1_agent",
|
|
include_base_tools=False,
|
|
tool_ids=[approval_tool_fixture.id, dice_tool_fixture.id],
|
|
include_base_tool_rules=False,
|
|
tool_rules=[],
|
|
model="anthropic/claude-sonnet-4-5-20250929",
|
|
embedding="openai/text-embedding-3-small",
|
|
tags=["approval_test"],
|
|
)
|
|
# Enable parallel tool calls for testing
|
|
agent_state = client.agents.update(agent_id=agent_state.id, parallel_tool_calls=True)
|
|
yield agent_state
|
|
|
|
client.agents.delete(agent_id=agent_state.id)
|
|
|
|
|
|
# ------------------------------
|
|
# Error Test Cases
|
|
# ------------------------------
|
|
|
|
|
|
def test_send_approval_without_pending_request(client, agent):
|
|
with pytest.raises(APIError, match="No tool call is currently awaiting approval"):
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": FAKE_REQUEST_ID,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
|
|
def test_send_user_message_with_pending_request(client, agent):
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
print("RESPONSE", response)
|
|
for message in response.messages:
|
|
print("MESSAGE", message)
|
|
|
|
with pytest.raises(APIError, match="Please approve or deny the pending request before continuing"):
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[{"role": "user", "content": "hi"}],
|
|
)
|
|
|
|
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
|
|
|
|
|
|
def test_send_approval_message_with_incorrect_request_id(client, agent):
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
|
|
with pytest.raises(APIError, match="Invalid tool call IDs"):
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": FAKE_REQUEST_ID,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
|
|
|
|
|
|
# ------------------------------
|
|
# Request Test Cases
|
|
# ------------------------------
|
|
|
|
|
|
def test_invoke_approval_request(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
|
|
messages = response.messages
|
|
|
|
assert messages is not None
|
|
assert messages[-1].message_type == "approval_request_message"
|
|
assert messages[-1].tool_call is not None
|
|
assert messages[-1].tool_call.name == "get_secret_code_tool"
|
|
assert messages[-1].tool_calls is not None
|
|
assert len(messages[-1].tool_calls) == 1
|
|
assert messages[-1].tool_calls[0].name == "get_secret_code_tool"
|
|
|
|
# v3/v1 path: approval request tool args must not include request_heartbeat
|
|
import json as _json
|
|
|
|
_args = _json.loads(messages[-1].tool_call.arguments)
|
|
assert "request_heartbeat" not in _args
|
|
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
|
|
# Test pending_approval relationship field
|
|
agent_with_pending = client.agents.retrieve(agent_id=agent.id, include=["agent.pending_approval"])
|
|
assert agent_with_pending.pending_approval is not None
|
|
# Client SDK returns it as a dict, so use dict access
|
|
assert agent_with_pending.pending_approval["tool_call"]["name"] == "get_secret_code_tool"
|
|
assert len(agent_with_pending.pending_approval["tool_calls"]) > 0
|
|
assert agent_with_pending.pending_approval["tool_calls"][0]["name"] == "get_secret_code_tool"
|
|
assert agent_with_pending.pending_approval["tool_calls"][0]["tool_call_id"] == response.messages[-1].tool_call.tool_call_id
|
|
|
|
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
|
|
|
|
# After approval, pending_approval should be None (latest message is no longer approval request)
|
|
agent_after_approval = client.agents.retrieve(agent_id=agent.id, include=["agent.pending_approval"])
|
|
assert agent_after_approval.pending_approval is None
|
|
|
|
|
|
def test_invoke_approval_request_stream(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert messages[-3].message_type == "approval_request_message"
|
|
assert messages[-3].tool_call is not None
|
|
assert messages[-3].tool_call.name == "get_secret_code_tool"
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
|
|
approve_tool_call(client, agent.id, messages[-3].tool_call.tool_call_id)
|
|
|
|
|
|
def test_invoke_tool_after_turning_off_requires_approval(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
approval_tool_fixture: Tool,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
messages = accumulate_chunks(response)
|
|
|
|
client.agents.tools.update_approval(
|
|
agent_id=agent.id,
|
|
tool_name=approval_tool_fixture.name,
|
|
body_requires_approval=False,
|
|
)
|
|
|
|
response = client.agents.messages.stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert 6 <= len(messages) <= 9
|
|
idx = 0
|
|
|
|
assert messages[idx].message_type == "reasoning_message"
|
|
idx += 1
|
|
|
|
try:
|
|
assert messages[idx].message_type == "assistant_message"
|
|
idx += 1
|
|
except:
|
|
pass
|
|
|
|
assert messages[idx].message_type == "tool_call_message"
|
|
idx += 1
|
|
assert messages[idx].message_type == "tool_return_message"
|
|
idx += 1
|
|
|
|
assert messages[idx].message_type == "reasoning_message"
|
|
idx += 1
|
|
try:
|
|
assert messages[idx].message_type == "assistant_message"
|
|
idx += 1
|
|
except:
|
|
assert messages[idx].message_type == "tool_call_message"
|
|
idx += 1
|
|
assert messages[idx].message_type == "tool_return_message"
|
|
idx += 1
|
|
|
|
|
|
# ------------------------------
|
|
# Approve Test Cases
|
|
# ------------------------------
|
|
|
|
|
|
def test_approve_tool_call_request(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert messages[0].message_type == "tool_return_message"
|
|
assert messages[0].tool_call_id == tool_call_id
|
|
assert messages[0].status == "success"
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_approve_cursor_fetch(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
last_message_id = response.messages[0].id
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
|
|
assert messages[0].message_type == "user_message"
|
|
assert messages[-1].message_type == "approval_request_message"
|
|
# Ensure no request_heartbeat on approval request
|
|
import json as _json
|
|
|
|
_args = _json.loads(messages[-1].tool_call.arguments)
|
|
assert "request_heartbeat" not in _args
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
|
|
assert messages[0].message_type == "approval_response_message"
|
|
assert messages[0].approval_request_id == tool_call_id
|
|
assert messages[0].approve is True
|
|
assert messages[0].approvals[0].approve is True
|
|
assert messages[0].approvals[0].tool_call_id == tool_call_id
|
|
assert messages[1].message_type == "tool_return_message"
|
|
assert messages[1].status == "success"
|
|
|
|
|
|
def test_approve_with_context_check(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
try:
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
except Exception as e:
|
|
if len(messages) > 4:
|
|
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
|
|
raise e
|
|
|
|
|
|
def test_approve_and_follow_up(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_approve_and_follow_up_with_error(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
# Mock the streaming adapter to return llm invocation failure on the follow up turn
|
|
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
with pytest.raises(APIError, match="TEST: Mocked error"):
|
|
messages = accumulate_chunks(response)
|
|
|
|
# Ensure that agent is not bricked
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) == 4 or len(messages) == 5
|
|
assert messages[0].message_type == "reasoning_message"
|
|
if len(messages) == 4:
|
|
assert messages[1].message_type == "assistant_message"
|
|
else:
|
|
assert messages[1].message_type == "tool_call_message"
|
|
assert messages[2].message_type == "tool_return_message"
|
|
|
|
|
|
def test_approve_with_user_message(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": "The secret code should not contain any special characters.",
|
|
},
|
|
],
|
|
)
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
# ------------------------------
|
|
# Deny Test Cases
|
|
# ------------------------------
|
|
|
|
|
|
def test_deny_tool_call_request(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": tool_call_id,
|
|
"reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
if messages[0].message_type == "assistant_message":
|
|
assert SECRET_CODE in messages[0].content
|
|
elif messages[1].message_type == "assistant_message":
|
|
assert SECRET_CODE in messages[1].content
|
|
|
|
|
|
def test_deny_cursor_fetch(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
last_message_id = response.messages[0].id
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
|
|
assert messages[0].message_type == "user_message"
|
|
assert messages[-1].message_type == "approval_request_message"
|
|
assert messages[-1].tool_call.tool_call_id == tool_call_id
|
|
# Ensure no request_heartbeat on approval request
|
|
# import json as _json
|
|
|
|
# _args = _json.loads(messages[2].tool_call.arguments)
|
|
# assert "request_heartbeat" not in _args
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": tool_call_id,
|
|
"reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
|
|
assert messages[0].message_type == "approval_response_message"
|
|
assert messages[0].approvals[0].approve == False
|
|
assert messages[0].approvals[0].tool_call_id == tool_call_id
|
|
assert messages[0].approvals[0].reason == f"You don't need to call the tool, the secret code is {SECRET_CODE}"
|
|
assert messages[1].message_type == "tool_return_message"
|
|
assert messages[1].status == "error"
|
|
|
|
|
|
def test_deny_with_context_check(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": tool_call_id,
|
|
"reason": "Cancelled by user. Instead of responding, wait for next user input before replying.",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
try:
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
except Exception as e:
|
|
if len(messages) > 4:
|
|
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
|
|
raise e
|
|
|
|
|
|
def test_deny_and_follow_up(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": tool_call_id,
|
|
"reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) > 2
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_deny_and_follow_up_with_error(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
# Mock the streaming adapter to return llm invocation failure on the follow up turn
|
|
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": tool_call_id,
|
|
"reason": f"You don't need to call the tool, the secret code is {SECRET_CODE}",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
with pytest.raises(APIError, match="TEST: Mocked error"):
|
|
messages = accumulate_chunks(response)
|
|
|
|
# Ensure that agent is not bricked
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) > 2
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_deny_with_user_message(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": f"Actually, you don't need to call the tool, the secret code is {SECRET_CODE}",
|
|
},
|
|
],
|
|
)
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) > 2
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
# --------------------------------
|
|
# Client-Side Execution Test Cases
|
|
# --------------------------------
|
|
|
|
|
|
def test_client_side_tool_call_request(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
if messages[0].message_type == "assistant_message":
|
|
assert SECRET_CODE in messages[1].content
|
|
elif messages[1].message_type == "assistant_message":
|
|
assert SECRET_CODE in messages[2].content
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_client_side_tool_call_cursor_fetch(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
last_message_id = response.messages[0].id
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
|
|
assert messages[0].message_type == "user_message"
|
|
assert messages[-1].message_type == "approval_request_message"
|
|
assert messages[-1].tool_call.tool_call_id == tool_call_id
|
|
# Ensure no request_heartbeat on approval request
|
|
# import json as _json
|
|
|
|
# _args = _json.loads(messages[2].tool_call.arguments)
|
|
# assert "request_heartbeat" not in _args
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id).items
|
|
assert messages[0].message_type == "approval_response_message"
|
|
assert messages[0].approvals[0].type == "tool"
|
|
assert messages[0].approvals[0].tool_call_id == tool_call_id
|
|
assert messages[0].approvals[0].tool_return == SECRET_CODE
|
|
assert messages[0].approvals[0].status == "success"
|
|
assert messages[1].message_type == "tool_return_message"
|
|
assert messages[1].status == "success"
|
|
assert messages[1].tool_call_id == tool_call_id
|
|
assert messages[1].tool_return == SECRET_CODE
|
|
|
|
|
|
def test_client_side_tool_call_with_context_check(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
try:
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
except Exception as e:
|
|
if len(messages) > 4:
|
|
raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.")
|
|
raise e
|
|
|
|
|
|
def test_client_side_tool_call_and_follow_up(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) > 2
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_client_side_tool_call_and_follow_up_with_error(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
# Mock the streaming adapter to return llm invocation failure on the follow up turn
|
|
with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")):
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
with pytest.raises(APIError, match="TEST: Mocked error"):
|
|
messages = accumulate_chunks(response)
|
|
|
|
# Ensure that agent is not bricked
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) > 2
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_client_side_tool_call_with_user_message(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": "The secret code should not contain any special characters.",
|
|
},
|
|
],
|
|
)
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) > 2
|
|
assert messages[-2].message_type == "stop_reason"
|
|
assert messages[-1].message_type == "usage_statistics"
|
|
|
|
|
|
def test_parallel_tool_calling(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1).items[0].id
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_PARALLEL_TOOL_CALL,
|
|
)
|
|
|
|
messages = response.messages
|
|
|
|
assert messages is not None
|
|
assert messages[-2].message_type == "tool_call_message"
|
|
assert len(messages[-2].tool_calls) == 1
|
|
assert messages[-2].tool_calls[0].name == "roll_dice_tool"
|
|
assert "6" in messages[-2].tool_calls[0].arguments
|
|
dice_tool_call_id = messages[-2].tool_calls[0].tool_call_id
|
|
|
|
assert messages[-1].message_type == "approval_request_message"
|
|
assert messages[-1].tool_call is not None
|
|
assert messages[-1].tool_call.name == "get_secret_code_tool"
|
|
|
|
assert len(messages[-1].tool_calls) == 3
|
|
assert messages[-1].tool_calls[0].name == "get_secret_code_tool"
|
|
assert "hello world" in messages[-1].tool_calls[0].arguments
|
|
approve_tool_call_id = messages[-1].tool_calls[0].tool_call_id
|
|
assert messages[-1].tool_calls[1].name == "get_secret_code_tool"
|
|
assert "hello letta" in messages[-1].tool_calls[1].arguments
|
|
deny_tool_call_id = messages[-1].tool_calls[1].tool_call_id
|
|
assert messages[-1].tool_calls[2].name == "get_secret_code_tool"
|
|
assert "hello test" in messages[-1].tool_calls[2].arguments
|
|
client_side_tool_call_id = messages[-1].tool_calls[2].tool_call_id
|
|
|
|
# ensure context is not bricked
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": approve_tool_call_id,
|
|
},
|
|
{
|
|
"type": "approval",
|
|
"approve": False,
|
|
"tool_call_id": deny_tool_call_id,
|
|
},
|
|
{
|
|
"type": "tool",
|
|
"tool_call_id": client_side_tool_call_id,
|
|
"tool_return": SECRET_CODE,
|
|
"status": "success",
|
|
},
|
|
],
|
|
},
|
|
],
|
|
)
|
|
|
|
messages = response.messages
|
|
|
|
assert messages is not None
|
|
assert len(messages) == 1 or len(messages) == 3 or len(messages) == 4
|
|
assert messages[0].message_type == "tool_return_message"
|
|
assert len(messages[0].tool_returns) == 4
|
|
for tool_return in messages[0].tool_returns:
|
|
if tool_return.tool_call_id == approve_tool_call_id:
|
|
assert tool_return.status == "success"
|
|
elif tool_return.tool_call_id == deny_tool_call_id:
|
|
assert tool_return.status == "error"
|
|
elif tool_return.tool_call_id == client_side_tool_call_id:
|
|
assert tool_return.status == "success"
|
|
assert tool_return.tool_return == SECRET_CODE
|
|
else:
|
|
assert tool_return.tool_call_id == dice_tool_call_id
|
|
assert tool_return.status == "success"
|
|
if len(messages) == 3:
|
|
assert messages[1].message_type == "reasoning_message"
|
|
assert messages[2].message_type == "assistant_message"
|
|
elif len(messages) == 4:
|
|
assert messages[1].message_type == "reasoning_message"
|
|
assert messages[2].message_type == "tool_call_message"
|
|
assert messages[3].message_type == "tool_return_message"
|
|
|
|
# ensure context is not bricked
|
|
client.get(f"/v1/agents/{agent.id}/context", cast_to=dict[str, Any])
|
|
|
|
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor).items
|
|
assert len(messages) > 6
|
|
assert messages[0].message_type == "user_message"
|
|
assert messages[1].message_type == "reasoning_message"
|
|
assert messages[2].message_type == "assistant_message"
|
|
assert messages[3].message_type == "tool_call_message"
|
|
assert messages[4].message_type == "approval_request_message"
|
|
assert messages[5].message_type == "approval_response_message"
|
|
assert messages[6].message_type == "tool_return_message"
|
|
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
stream_tokens=True,
|
|
)
|
|
|
|
messages = accumulate_chunks(response)
|
|
|
|
assert messages is not None
|
|
assert len(messages) == 4
|
|
assert messages[0].message_type == "reasoning_message"
|
|
assert messages[1].message_type == "assistant_message"
|
|
assert messages[2].message_type == "stop_reason"
|
|
assert messages[3].message_type == "usage_statistics"
|
|
|
|
|
|
def test_agent_records_last_stop_reason_after_approval_flow(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
"""
|
|
Test that the agent's last_stop_reason is properly updated after a human-in-the-loop flow.
|
|
This verifies the integration between run completion and agent state updates.
|
|
"""
|
|
# Get initial agent state
|
|
initial_agent = client.agents.retrieve(agent_id=agent.id)
|
|
initial_stop_reason = initial_agent.last_stop_reason
|
|
|
|
# Trigger approval request
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
|
|
# Verify we got an approval request
|
|
messages = response.messages
|
|
assert messages is not None
|
|
assert messages[-1].message_type == "approval_request_message"
|
|
|
|
# Check agent after approval request (run should be paused with requires_approval)
|
|
agent_after_request = client.agents.retrieve(agent_id=agent.id)
|
|
assert agent_after_request.last_stop_reason == "requires_approval"
|
|
|
|
# Approve the tool call
|
|
approve_tool_call(client, agent.id, response.messages[-1].tool_call.tool_call_id)
|
|
|
|
# Check agent after approval (run should complete with end_turn or similar)
|
|
agent_after_approval = client.agents.retrieve(agent_id=agent.id)
|
|
# After approval and run completion, stop reason should be updated (could be end_turn or other terminal reason)
|
|
assert agent_after_approval.last_stop_reason is not None
|
|
assert agent_after_approval.last_stop_reason != initial_stop_reason # Should be different from initial
|
|
|
|
# Send follow-up message to complete the flow
|
|
response2 = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_FOLLOW_UP,
|
|
)
|
|
|
|
# Verify final agent state has the most recent stop reason
|
|
final_agent = client.agents.retrieve(agent_id=agent.id)
|
|
assert final_agent.last_stop_reason is not None
|
|
|
|
|
|
def test_approve_with_cancellation(
|
|
client: Letta,
|
|
agent: AgentState,
|
|
) -> None:
|
|
"""
|
|
Test that when approval and cancellation happen simultaneously,
|
|
the stream returns stop_reason: cancelled and stream_was_cancelled is set.
|
|
"""
|
|
import threading
|
|
import time
|
|
|
|
# Step 1: Send message that triggers approval request
|
|
response = client.agents.messages.create(
|
|
agent_id=agent.id,
|
|
messages=USER_MESSAGE_TEST_APPROVAL,
|
|
)
|
|
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
|
|
|
# Step 2: Start cancellation in background thread
|
|
def cancel_after_delay():
|
|
time.sleep(0.3) # Wait for stream to start
|
|
client.agents.messages.cancel(agent_id=agent.id)
|
|
|
|
cancel_thread = threading.Thread(target=cancel_after_delay, daemon=True)
|
|
cancel_thread.start()
|
|
|
|
# Step 3: Start approval stream (will be cancelled during processing)
|
|
response = client.agents.messages.stream(
|
|
agent_id=agent.id,
|
|
messages=[
|
|
{
|
|
"type": "approval",
|
|
"approvals": [
|
|
{
|
|
"type": "approval",
|
|
"approve": True,
|
|
"tool_call_id": tool_call_id,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
stream_tokens=True,
|
|
)
|
|
|
|
# Step 4: Accumulate chunks
|
|
messages = accumulate_chunks(response)
|
|
|
|
# Step 5: Verify we got chunks AND a cancelled stop reason
|
|
assert len(messages) > 0, "Should receive at least some chunks before cancellation"
|
|
|
|
# Find stop_reason in messages
|
|
stop_reasons = [msg for msg in messages if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
|
|
assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}"
|
|
assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"
|
|
|
|
# Step 6: Verify run status is cancelled
|
|
runs = client.runs.list(agent_ids=[agent.id])
|
|
latest_run = runs.items[0]
|
|
assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'"
|
|
|
|
# Wait for cancel thread to finish
|
|
cancel_thread.join(timeout=1.0)
|
|
|
|
logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")
|