Files
letta-server/tests/integration_test_cancellation.py
cthomas c162de5127 fix: use shared event + .athrow() to properly set stream_was_cancelled flag (#9019)
fix: use shared event + .athrow() to properly set stream_was_cancelled flag

**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.

**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]

**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation

**Tests:**
- test_tool_call with cancellation (OpenAI models) 
- test_approve_with_cancellation (approval flow + concurrent cancel) 

**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race

👾 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
2026-01-29 12:44:04 -08:00

206 lines
6.5 KiB
Python

import asyncio
import json
import os
import threading
from typing import Any, List
import pytest
from dotenv import load_dotenv
from letta_client import AsyncLetta
from letta_client.types import MessageCreateParam
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType, JobStatus
from letta.schemas.llm_config import LLMConfig
# Module-level logger for this test module (not referenced elsewhere in this file).
logger = get_logger(__name__)
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """
    Load an ``LLMConfig`` from a JSON file.

    Args:
        filename: Name of the JSON config file inside ``llm_config_dir``.
        llm_config_dir: Directory holding the model config files. The default
            is a relative path, so it resolves against the current working
            directory (presumably the ``letta-server`` root when running pytest).

    Returns:
        An ``LLMConfig`` built from the file's top-level JSON object.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    config_path = os.path.join(llm_config_dir, filename)
    # JSON is UTF-8 by spec; don't rely on the platform's locale-dependent default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)
# Model config files exercised by the parametrized tests below. Setting the
# LLM_CONFIG_FILE environment variable narrows the run to that single file.
all_configs = [
    "openai-gpt-4o-mini.json",
    "openai-o3.json",
    "openai-gpt-5.json",
    "claude-4-5-sonnet.json",
    "claude-4-1-opus.json",
    "gemini-2.5-flash.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
if requested:
    filenames = [requested]
else:
    filenames = all_configs
# Parsed eagerly at import time, so a missing or malformed config file fails test collection.
TESTED_LLM_CONFIGS: List[LLMConfig] = list(map(get_llm_config, filenames))
# NOTE: this function is uploaded to the server as a tool by
# `client.tools.upsert_from_function(func=roll_dice)` in the `agent_state`
# fixture, so its source text and docstring are effectively test data.
# The `random` import is kept inside the body, presumably so the uploaded
# source is self-contained when executed server-side (confirm against
# upsert_from_function). The Google-style docstring is presumably parsed into
# the tool's schema; keep it in sync with the signature.
def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.
    Args:
        num_sides (int): The number of sides on the die.
    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    import random

    return random.randint(1, num_sides)
# A single user turn asking the agent to call the roll_dice tool and report the result.
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
    MessageCreateParam(
        role="user",
        content=(
            "This is an automated test message. "
            "Call the roll_dice tool with 16 sides and reply back to me with the outcome."
        ),
    ),
]
async def accumulate_chunks(chunks: Any) -> List[Any]:
    """
    Collapse an async stream into one entry per run of same-typed chunks.

    Consecutive chunks that share a ``message_type`` form a run. Only the
    final chunk of each run is kept; earlier chunks in the run are discarded,
    not merged. Runs appear in the output in stream order.
    """
    collapsed: List[Any] = []
    last_chunk = None
    last_type = None
    async for chunk in chunks:
        chunk_type = chunk.message_type
        # A type change closes the previous run; emit its last chunk.
        if chunk_type != last_type and last_chunk is not None:
            collapsed.append(last_chunk)
        last_chunk = chunk
        last_type = chunk_type
    if last_chunk is not None:
        collapsed.append(last_chunk)
    return collapsed
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5):
    """
    Wait ``delay`` seconds, then call the messages cancel endpoint for ``agent_id``.

    Meant to be scheduled as a background task next to a streaming request, so
    the cancel request arrives while the run is (presumably) still in flight.
    """
    await asyncio.sleep(delay)
    messages_resource = client.agents.messages
    await messages_resource.cancel(agent_id=agent_id)
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provide the base URL of the Letta server used by this module's tests.

    If ``LETTA_SERVER_URL`` is set (non-empty), it is returned as-is and no
    readiness check is performed. Otherwise a server is started in a daemon
    thread via ``start_server(debug=True)``, and ``http://localhost:8283/v1/health``
    is polled every 0.5 s until it answers with a non-5xx status.

    Raises:
        TimeoutError: if the locally started server is not ready within 60 s.
    """

    def _run_server() -> None:
        # Load .env (if present) before importing and starting the server.
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    configured_url = os.getenv("LETTA_SERVER_URL")
    # Treat an empty LETTA_SERVER_URL like an unset one instead of polling "/v1/health".
    url: str = configured_url or "http://localhost:8283"
    if not configured_url:
        redis_host = os.getenv("LETTA_REDIS_HOST")
        if redis_host:
            print(f"Redis is configured at {redis_host}:{os.getenv('LETTA_REDIS_PORT', '6379')}")
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        import time

        import httpx

        timeout_seconds = 60
        # monotonic() is immune to wall-clock adjustments, unlike time.time().
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                response = httpx.get(url + "/v1/health", timeout=1.0, follow_redirects=True)
            except httpx.HTTPError:
                # Connection refused / read timeouts are expected while the server boots.
                pass
            else:
                # Any non-5xx answer (even 404) means the HTTP server is accepting requests.
                if response.status_code < 500:
                    break
            time.sleep(0.5)
        else:
            raise TimeoutError(f"Server at {url} did not become ready in {timeout_seconds}s")
    return url
@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncLetta:
    """
    Yield an asynchronous Letta REST client bound to ``server_url``.

    The client is entered as an async context manager so its underlying HTTP
    connection pool is closed at teardown rather than leaking across tests.
    """
    async with AsyncLetta(base_url=server_url) as client_instance:
        yield client_instance
@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """
    Yield a freshly created ``letta_v1_agent`` whose only tool is ``roll_dice``.

    The roll_dice tool is upserted first. Base tools are disabled, and the agent
    is deleted again at fixture teardown. The upserted tool itself is left in place.
    """
    roll_dice_tool = await client.tools.upsert_from_function(func=roll_dice)
    created_agent = await client.agents.create(
        name="test_agent",
        agent_type=AgentType.letta_v1_agent,
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        include_base_tools=False,
        tool_ids=[roll_dice_tool.id],
        tags=["test"],
    )
    yield created_agent
    await client.agents.delete(created_agent.id)
@pytest.mark.skipif(not os.getenv("LETTA_REDIS_HOST"), reason="Redis is required for background streaming (set LETTA_REDIS_HOST to enable)")
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_background_streaming_cancellation(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Cancel a background streaming run mid-flight and verify it ends as cancelled.

    1. Switch the agent to ``llm_config``.
    2. Schedule a cancel 1.5 s out, then open a token stream with
       ``background=True`` and drain it.
    3. Find the run (from the first streamed chunk's ``run_id``, or else by
       listing the agent's cancelled runs) and check its status.
    4. Replay the run's stream from the start and assert it contains exactly
       one ``stop_reason`` chunk with reason ``"cancelled"``.
    """
    agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config)

    delay = 1.5
    cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
    try:
        response = await client.agents.messages.stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
            stream_tokens=True,
            background=True,
        )
        messages = await accumulate_chunks(response)
    except BaseException:
        # Don't leave the cancel task pending on the event loop if streaming itself failed.
        cancellation_task.cancel()
        raise
    await cancellation_task

    # An empty stream (cancelled before any chunk arrived) must fall through to the
    # run-list lookup below instead of raising IndexError.
    run_id = getattr(messages[0], "run_id", None) if messages else None
    if run_id:
        run = await client.runs.retrieve(run_id=run_id)
        assert run.status == JobStatus.cancelled
    else:
        runs = await client.runs.list(agent_id=agent_state.id, stop_reason="cancelled", limit=1)
        # Count the page's items directly. list() on an async page does not yield its items
        # (it presumably falls back to iterating the pydantic model's fields).
        assert len(runs.items) == 1
        run_id = runs.items[0].id

    # Replay the persisted stream from the beginning.
    response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
    messages_from_stream = await accumulate_chunks(response)
    assert len(messages_from_stream) > 0

    # Verify the stream contains stop_reason: cancelled (from our new cancellation logic)
    stop_reasons = [msg for msg in messages_from_stream if getattr(msg, "message_type", None) == "stop_reason"]
    assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}"
    assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"