fix(core): Fix agent loop continuing after cancellation in letta_agent_v3 [LET-6006] (#5905)
* Fix agent loop continuing after cancellation in letta_agent_v3 Bug: When a run is cancelled, _check_run_cancellation() sets self.should_continue=False and returns early from _step(), but the outer for loop (line 245) continues to the next iteration, executing subsequent steps even though cancellation was requested. Symptom: User hits cancel during step 1, backend marks run as cancelled, but agent continues executing steps 2, 3, etc. Root cause: After the 'async for chunk in response' loop completes (line 255), there was no check of self.should_continue before continuing to the next iteration of the outer step loop. Fix: Added 'if not self.should_continue: break' check after the inner loop to exit the outer step loop when cancellation is detected. This makes v3 consistent with v2 which already had this check (line 306-307). 🐾 Generated with [Letta Code](https://letta.com) Co-authored-by: Letta <noreply@letta.com> * add integration tests * fix: misc fixes required to get cancellations to work on letta code localhost --------- Co-authored-by: Letta <noreply@letta.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Committed by: Caren Thomas
Parent commit: a44c05040a
Commit: a6077f3927
@@ -258,6 +258,10 @@ class LettaAgentV3(LettaAgentV2):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
first_chunk = False
|
||||
|
||||
# Check if step was cancelled - break out of the step loop
|
||||
if not self.should_continue:
|
||||
break
|
||||
|
||||
# Proactive summarization if approaching context limit
|
||||
if (
|
||||
self.last_step_usage
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.run import RunUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.streaming_response import RunCancelledException
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
@@ -242,6 +243,11 @@ async def create_background_stream_processor(
|
||||
except:
|
||||
pass
|
||||
|
||||
except RunCancelledException as e:
|
||||
# Handle cancellation gracefully - don't write error chunk, cancellation event was already sent
|
||||
logger.info(f"Stream processing stopped due to cancellation for run {run_id}")
|
||||
# The cancellation event was already yielded by cancellation_aware_stream_wrapper
|
||||
# Just mark as complete, don't write additional error chunks
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing stream for run {run_id}: {e}")
|
||||
# Write error chunk
|
||||
@@ -250,9 +256,8 @@ async def create_background_stream_processor(
|
||||
if run_manager and actor:
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value),
|
||||
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value, metadata={"error": str(e)}),
|
||||
actor=actor,
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
|
||||
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
|
||||
|
||||
@@ -144,7 +144,7 @@ async def cancellation_aware_stream_wrapper(
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
if current_time - last_cancellation_check >= cancellation_check_interval:
|
||||
try:
|
||||
run = await run_manager.get_run_by_id_async(run_id=run_id, actor=actor)
|
||||
run = await run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if run.status == RunStatus.cancelled:
|
||||
logger.info(f"Stream cancelled for run {run_id}, interrupting stream")
|
||||
# Send cancellation event to client
|
||||
@@ -152,6 +152,9 @@ async def cancellation_aware_stream_wrapper(
|
||||
yield f"data: {json.dumps(cancellation_event)}\n\n"
|
||||
# Raise custom exception for explicit run cancellation
|
||||
raise RunCancelledException(run_id, f"Run {run_id} was cancelled")
|
||||
except RunCancelledException:
|
||||
# Re-raise cancellation immediately, don't catch it
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log warning but don't fail the stream if cancellation check fails
|
||||
logger.warning(f"Failed to check run cancellation for run {run_id}: {e}")
|
||||
|
||||
@@ -37,7 +37,11 @@ from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
|
||||
from letta.server.rest_api.streaming_response import (
|
||||
StreamingResponseWithStatusCode,
|
||||
add_keepalive_to_stream,
|
||||
cancellation_aware_stream_wrapper,
|
||||
)
|
||||
from letta.services.run_manager import RunManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
@@ -130,9 +134,19 @@ class StreamingService:
|
||||
service_name="redis",
|
||||
)
|
||||
|
||||
# Wrap the agent loop stream with cancellation awareness for background task
|
||||
background_stream = raw_stream
|
||||
if settings.enable_cancellation_aware_streaming and run:
|
||||
background_stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=raw_stream,
|
||||
run_manager=self.runs_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
safe_create_task(
|
||||
create_background_stream_processor(
|
||||
stream_generator=raw_stream,
|
||||
stream_generator=background_stream,
|
||||
redis_client=redis_client,
|
||||
run_id=run.id,
|
||||
run_manager=self.server.run_manager,
|
||||
@@ -146,11 +160,19 @@ class StreamingService:
|
||||
run_id=run.id,
|
||||
)
|
||||
|
||||
# wrap client stream with cancellation awareness if enabled and tracking runs
|
||||
stream = raw_stream
|
||||
if settings.enable_cancellation_aware_streaming and settings.track_agent_run and run and not request.background:
|
||||
stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=raw_stream,
|
||||
run_manager=self.runs_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# conditionally wrap with keepalive based on request parameter
|
||||
if request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
||||
else:
|
||||
stream = raw_stream
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id)
|
||||
|
||||
result = StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
|
||||
@@ -1114,6 +1114,10 @@ def safe_create_task(coro, label: str = "background task"):
|
||||
try:
|
||||
await coro
|
||||
except Exception as e:
|
||||
# Don't log RunCancelledException as an error - it's expected when streams are cancelled
|
||||
if e.__class__.__name__ == "RunCancelledException":
|
||||
logger.info(f"{label} was cancelled (RunCancelledException)")
|
||||
return
|
||||
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
|
||||
|
||||
task = asyncio.create_task(wrapper())
|
||||
|
||||
tests/integration_test_cancellation.py — new file, 190 lines added
@@ -0,0 +1,190 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, MessageCreate
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import AgentType, JobStatus
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load an LLMConfig from a JSON file under the test config directory.

    Args:
        filename: Bare config filename, e.g. "openai-gpt-4o-mini.json".
        llm_config_dir: Directory the filename is resolved against.

    Returns:
        The parsed LLMConfig instance.
    """
    config_path = os.path.join(llm_config_dir, filename)
    with open(config_path, "r") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)
|
||||
|
||||
|
||||
# Candidate model configs exercised by the cancellation tests; one JSON file
# per provider/model under tests/configs/llm_model_configs.
all_configs = [
    "openai-gpt-4o-mini.json",
    "openai-o3.json",
    "openai-gpt-5.json",
    "claude-4-5-sonnet.json",
    "claude-4-1-opus.json",
    "gemini-2.5-flash.json",
]

# LLM_CONFIG_FILE (env var) narrows the matrix to a single config file;
# otherwise every config in all_configs is parametrized.
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
|
||||
def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.

    Args:
        num_sides (int): The number of sides on the die.

    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    # NOTE: docstring is kept verbatim — it is parsed into the tool schema when
    # the function is registered via upsert_from_function.
    from random import randint

    return randint(1, num_sides)
|
||||
|
||||
|
||||
# Prompt that reliably triggers a tool call (roll_dice), so the agent run is
# mid-execution when the test issues a cancellation.
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
    )
]
|
||||
|
||||
|
||||
async def accumulate_chunks(chunks: Any) -> List[Any]:
    """
    Collapse a stream of chunks into one message per run of consecutive chunks
    sharing the same ``message_type``.

    Each incoming chunk overwrites the current one, so the LAST chunk of every
    same-type run is what gets kept. A type change flushes the previous run's
    final chunk into the result list.

    Args:
        chunks: Async iterator of chunk objects exposing ``message_type``.

    Returns:
        List of the final chunk from each consecutive same-type run, in order.
    """
    # Fix: the original if/else assigned `current_message = chunk` in BOTH
    # branches (and kept a redundant `current_message_type` temp); the branches
    # are collapsed here with identical behavior.
    messages = []
    current_message = None
    prev_message_type = None
    async for chunk in chunks:
        # A type change closes out the previous run; flush its last chunk.
        if prev_message_type != chunk.message_type:
            messages.append(current_message)
        current_message = chunk
        prev_message_type = chunk.message_type
    # Flush the final run (None when the stream was empty).
    messages.append(current_message)
    # The very first flush appends the initial None placeholder; drop it.
    return [m for m in messages if m is not None]
|
||||
|
||||
|
||||
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5):
    # Give the agent run time to start streaming, then ask the server to
    # cancel the agent's active run(s). Scheduled as a background task so it
    # races the stream consumption in the test.
    await asyncio.sleep(delay)
    await client.agents.messages.cancel(agent_id=agent_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def server_url() -> str:
    """
    Provides the URL for the Letta server.
    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.
    """

    def _run_server() -> None:
        # Load .env before importing the app so server settings see env vars.
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        # Daemon thread: dies with the test process, no explicit shutdown.
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll the health endpoint until the server accepts connections.
        timeout_seconds = 30
        import time

        import httpx

        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = httpx.get(url + "/v1/health", timeout=1.0)
                if response.status_code == 200:
                    break
            except Exception:
                # Server not up yet (e.g. connection refused); keep polling.
                pass
            time.sleep(0.5)
        else:
            # while/else: loop exhausted the timeout without a healthy response.
            raise TimeoutError(f"Server at {url} did not become ready in {timeout_seconds}s")

    return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncLetta:
    """
    Yield an asynchronous Letta REST client bound to the test server.
    """
    yield AsyncLetta(base_url=server_url)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.
    The agent is configured with the roll_dice tool.

    The agent is deleted during teardown even when the test body raises.
    """
    dice_tool = await client.tools.upsert_from_function(func=roll_dice)

    agent_state_instance = await client.agents.create(
        agent_type=AgentType.letta_v1_agent,
        name="test_agent",
        include_base_tools=False,
        tool_ids=[dice_tool.id],
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        tags=["test"],
    )
    # Fix: without try/finally a failing test skips the post-yield cleanup and
    # leaks the agent on the server; finally guarantees deletion either way.
    try:
        yield agent_state_instance
    finally:
        await client.agents.delete(agent_state_instance.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_background_streaming_cancellation(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    End-to-end check that cancelling a background streaming run actually stops
    it: the run must end with status 'cancelled', and replaying the persisted
    stream afterwards must still yield the chunks emitted before cancellation.
    """
    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Fire the cancel concurrently while the stream below is being consumed.
    # gpt-5 is given a longer head start before cancelling — presumably because
    # it is slower to begin streaming; TODO confirm timing.
    delay = 5 if llm_config.model == "gpt-5" else 0.5
    _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))

    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_ROLL_DICE,
        stream_tokens=True,
        background=True,
    )
    messages = await accumulate_chunks(response)
    # Grab the run id from the first accumulated message of the stream.
    run_id = messages[0].run_id

    # Ensure the cancellation request has completed before asserting status.
    await _cancellation_task

    # The backend must have marked the run cancelled (not completed/failed).
    run = await client.runs.retrieve(run_id=run_id)
    assert run.status == JobStatus.cancelled

    # Replaying the stream from the beginning should return the chunks that
    # were produced before the cancellation took effect.
    response = client.runs.stream(run_id=run_id, starting_after=0)
    messages_from_stream = await accumulate_chunks(response)
    assert len(messages_from_stream) > 0
|
||||
Reference in New Issue
Block a user