fix: use shared event + .athrow() to properly set stream_was_cancelled flag
**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.
**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation
**Tests:**
- test_tool_call with cancellation (OpenAI models) ✅
- test_approve_with_cancellation (approval flow + concurrent cancel) ✅
**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race
👾 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
206 lines
6.5 KiB
Python
206 lines
6.5 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import threading
|
|
from typing import Any, List
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
from letta_client import AsyncLetta
|
|
from letta_client.types import MessageCreateParam
|
|
|
|
from letta.log import get_logger
|
|
from letta.schemas.agent import AgentState
|
|
from letta.schemas.enums import AgentType, JobStatus
|
|
from letta.schemas.llm_config import LLMConfig
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load an LLMConfig from a JSON file located under the config directory.

    Args:
        filename: Name of the JSON config file (e.g. "openai-gpt-4o-mini.json").
        llm_config_dir: Directory containing model config files.

    Returns:
        The parsed LLMConfig instance.
    """
    path = os.path.join(llm_config_dir, filename)
    with open(path, "r") as config_file:
        return LLMConfig(**json.load(config_file))
|
|
|
|
|
|
# Candidate model configs to exercise; each entry names a JSON file under
# tests/configs/llm_model_configs (resolved by get_llm_config).
all_configs = [
    "openai-gpt-4o-mini.json",
    "openai-o3.json",
    "openai-gpt-5.json",
    "claude-4-5-sonnet.json",
    "claude-4-1-opus.json",
    "gemini-2.5-flash.json",
]

# Setting LLM_CONFIG_FILE restricts the parametrized tests to a single config.
requested = os.getenv("LLM_CONFIG_FILE")
filenames: List[str] = [requested] if requested else all_configs
# Loaded eagerly at import time so pytest.mark.parametrize can consume the list.
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
|
|
|
|
|
def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.

    Args:
        num_sides (int): The number of sides on the die.

    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    # Local import: this function is uploaded as an agent tool via
    # upsert_from_function, so it must be self-contained.
    from random import randint

    return randint(1, num_sides)
|
|
|
|
|
|
# Prompt used by the cancellation tests: it instructs the agent to call the
# roll_dice tool, so a tool-call round-trip is in flight when /cancel fires.
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
    MessageCreateParam(
        role="user",
        content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
    )
]
|
|
|
|
|
|
async def accumulate_chunks(chunks: Any) -> List[Any]:
    """
    Collapse a chunk stream into one message per run of consecutive chunks
    that share a message_type, keeping the last chunk seen in each run.
    """
    collected: List[Any] = []
    last_chunk = None
    last_type = None
    async for chunk in chunks:
        if chunk.message_type != last_type:
            # Type changed: the previous run is finished, record its final chunk.
            # (On the first iteration this appends None, filtered out below.)
            collected.append(last_chunk)
        last_chunk = chunk
        last_type = chunk.message_type
    # Flush the final run (or None if the stream was empty).
    collected.append(last_chunk)
    return [m for m in collected if m is not None]
|
|
|
|
|
|
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5):
    """Sleep for `delay` seconds, then issue a cancel for the agent's messages.

    Presumably this hits the /cancel endpoint for the agent's in-flight run —
    confirm against the AsyncLetta client implementation.
    """
    await asyncio.sleep(delay)
    await client.agents.messages.cancel(agent_id=agent_id)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.

    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.
    """

    def _run_server() -> None:
        # Runs in a daemon thread; load .env before importing the app so
        # settings read at import time pick up the environment.
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    # Only spin up a local server (and wait for readiness) when no external
    # server URL was provided.
    if not os.getenv("LETTA_SERVER_URL"):
        if os.getenv("LETTA_REDIS_HOST"):
            print(f"Redis is configured at {os.getenv('LETTA_REDIS_HOST')}:{os.getenv('LETTA_REDIS_PORT', '6379')}")

        # Daemon thread so the process can exit even if the server never stops.
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        timeout_seconds = 60
        import time

        import httpx

        # Poll the health endpoint until the server answers (any non-5xx
        # status counts as "up", including 4xx) or the timeout elapses.
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = httpx.get(url + "/v1/health", timeout=1.0, follow_redirects=True)
                if response.status_code < 500:
                    break
            except Exception:
                # Connection refused while the server is still booting; retry.
                pass
            time.sleep(0.5)
        else:
            raise TimeoutError(f"Server at {url} did not become ready in {timeout_seconds}s")

    return url
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncLetta:
    """
    Yield an asynchronous Letta REST client pointed at the test server.
    """
    yield AsyncLetta(base_url=server_url)
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """
    Yield a freshly created test agent equipped with the roll_dice tool.

    The agent is deleted on teardown after the test finishes.
    """
    # Register (or refresh) the dice tool so the agent can call it.
    dice_tool = await client.tools.upsert_from_function(func=roll_dice)

    created = await client.agents.create(
        agent_type=AgentType.letta_v1_agent,
        name="test_agent",
        include_base_tools=False,
        tool_ids=[dice_tool.id],
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        tags=["test"],
    )
    yield created

    await client.agents.delete(created.id)
|
|
|
|
|
|
@pytest.mark.skipif(not os.getenv("LETTA_REDIS_HOST"), reason="Redis is required for background streaming (set LETTA_REDIS_HOST to enable)")
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_background_streaming_cancellation(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """Cancel a background streaming run mid-flight and verify the run ends
    with status/stop_reason 'cancelled' and the replayed stream carries
    exactly one 'stop_reason' event.
    """
    agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config)

    # Schedule /cancel to fire while the stream below is still in flight.
    # 1.5s is a timing heuristic: long enough for the run to start, short
    # enough to interrupt it before completion.
    delay = 1.5
    _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))

    response = await client.agents.messages.stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
        stream_tokens=True,
        background=True,
    )
    messages = await accumulate_chunks(response)
    # Not every chunk type carries a run_id; fall back to run listing below.
    run_id = messages[0].run_id if hasattr(messages[0], "run_id") else None

    # Ensure the cancel request actually completed (and surface its errors).
    await _cancellation_task

    if run_id:
        run = await client.runs.retrieve(run_id=run_id)
        assert run.status == JobStatus.cancelled
    else:
        runs = await client.runs.list(agent_id=agent_state.id, stop_reason="cancelled", limit=1)
        # NOTE(review): `len(list(runs))` treats the response as an iterable
        # while the next line reads `runs.items` — confirm which shape the
        # client returns; if `runs` is a one-shot iterator this pair is buggy.
        assert len(list(runs)) == 1
        run_id = runs.items[0].id

    # Replay the run's stream from the beginning and verify it is non-empty.
    response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
    messages_from_stream = await accumulate_chunks(response)
    assert len(messages_from_stream) > 0

    # Verify the stream contains stop_reason: cancelled (from our new cancellation logic)
    stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
    assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}"
    assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"