diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 575d8115..5da5f566 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -258,6 +258,10 @@ class LettaAgentV3(LettaAgentV2): yield f"data: {chunk.model_dump_json()}\n\n" first_chunk = False + # Check if step was cancelled - break out of the step loop + if not self.should_continue: + break + # Proactive summarization if approaching context limit if ( self.last_step_usage diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py index b361466d..06a4cac9 100644 --- a/letta/server/rest_api/redis_stream_manager.py +++ b/letta/server/rest_api/redis_stream_manager.py @@ -12,6 +12,7 @@ from letta.schemas.enums import RunStatus from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.run import RunUpdate from letta.schemas.user import User +from letta.server.rest_api.streaming_response import RunCancelledException from letta.services.run_manager import RunManager from letta.utils import safe_create_task @@ -242,6 +243,11 @@ async def create_background_stream_processor( except: pass + except RunCancelledException as e: + # Handle cancellation gracefully - don't write error chunk, cancellation event was already sent + logger.info(f"Stream processing stopped due to cancellation for run {run_id}") + # The cancellation event was already yielded by cancellation_aware_stream_wrapper + # Just mark as complete, don't write additional error chunks except Exception as e: logger.error(f"Error processing stream for run {run_id}: {e}") # Write error chunk @@ -250,9 +256,8 @@ async def create_background_stream_processor( if run_manager and actor: await run_manager.update_run_by_id_async( run_id=run_id, - update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value), + update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value, metadata={"error": str(e)}), actor=actor, - metadata={"error": str(e)}, ) error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"} diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 6dab5f23..5e400d63 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -144,7 +144,7 @@ async def cancellation_aware_stream_wrapper( current_time = asyncio.get_event_loop().time() if current_time - last_cancellation_check >= cancellation_check_interval: try: - run = await run_manager.get_run_by_id_async(run_id=run_id, actor=actor) + run = await run_manager.get_run_by_id(run_id=run_id, actor=actor) if run.status == RunStatus.cancelled: logger.info(f"Stream cancelled for run {run_id}, interrupting stream") # Send cancellation event to client @@ -152,6 +152,9 @@ async def cancellation_aware_stream_wrapper( yield f"data: {json.dumps(cancellation_event)}\n\n" # Raise custom exception for explicit run cancellation raise RunCancelledException(run_id, f"Run {run_id} was cancelled") + except RunCancelledException: + # Re-raise cancellation immediately, don't catch it + raise except Exception as e: # Log warning but don't fail the stream if cancellation check fails logger.warning(f"Failed to check run cancellation for run {run_id}: {e}") diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index 91d17045..373af5e8 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -37,7 +37,11 @@ from letta.schemas.run import Run as PydanticRun, RunUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator -from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream +from letta.server.rest_api.streaming_response import ( + StreamingResponseWithStatusCode, + add_keepalive_to_stream, + cancellation_aware_stream_wrapper, +) from letta.services.run_manager import RunManager from letta.settings import settings from letta.utils import safe_create_task @@ -130,9 +134,19 @@ class StreamingService: service_name="redis", ) + # Wrap the agent loop stream with cancellation awareness for background task + background_stream = raw_stream + if settings.enable_cancellation_aware_streaming and run: + background_stream = cancellation_aware_stream_wrapper( + stream_generator=raw_stream, + run_manager=self.runs_manager, + run_id=run.id, + actor=actor, + ) + safe_create_task( create_background_stream_processor( - stream_generator=raw_stream, + stream_generator=background_stream, redis_client=redis_client, run_id=run.id, run_manager=self.server.run_manager, @@ -146,11 +160,19 @@ class StreamingService: run_id=run.id, ) + # wrap client stream with cancellation awareness if enabled and tracking runs + stream = raw_stream + if settings.enable_cancellation_aware_streaming and settings.track_agent_run and run and not request.background: + stream = cancellation_aware_stream_wrapper( + stream_generator=raw_stream, + run_manager=self.runs_manager, + run_id=run.id, + actor=actor, + ) + # conditionally wrap with keepalive based on request parameter if request.include_pings and settings.enable_keepalive: - stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval, run_id=run.id) - else: - stream = raw_stream + stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval, run_id=run.id) result = StreamingResponseWithStatusCode( stream, diff --git a/letta/utils.py b/letta/utils.py index 184a98fb..3c529a71 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -1114,6 +1114,10 @@ def safe_create_task(coro, label: str = "background task"): try: await coro except Exception as e: + # Don't log RunCancelledException as an error - it's expected when streams are cancelled + if e.__class__.__name__ == "RunCancelledException": + logger.info(f"{label} was cancelled (RunCancelledException)") + return logger.exception(f"{label} failed with {type(e).__name__}: {e}") task = asyncio.create_task(wrapper()) diff --git a/tests/integration_test_cancellation.py b/tests/integration_test_cancellation.py new file mode 100644 index 00000000..509ad8db --- /dev/null +++ b/tests/integration_test_cancellation.py @@ -0,0 +1,190 @@ +import asyncio +import json +import os +import threading +from typing import Any, List + +import pytest +from dotenv import load_dotenv +from letta_client import AsyncLetta, MessageCreate + +from letta.log import get_logger +from letta.schemas.agent import AgentState +from letta.schemas.enums import AgentType, JobStatus +from letta.schemas.llm_config import LLMConfig + +logger = get_logger(__name__) + + +def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig: + filename = os.path.join(llm_config_dir, filename) + with open(filename, "r") as f: + config_data = json.load(f) + llm_config = LLMConfig(**config_data) + return llm_config + + +all_configs = [ + "openai-gpt-4o-mini.json", + "openai-o3.json", + "openai-gpt-5.json", + "claude-4-5-sonnet.json", + "claude-4-1-opus.json", + "gemini-2.5-flash.json", +] + +requested = os.getenv("LLM_CONFIG_FILE") +filenames = [requested] if requested else all_configs +TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] + + +def roll_dice(num_sides: int) -> int: + """ + Returns a random number between 1 and num_sides. + Args: + num_sides (int): The number of sides on the die. + Returns: + int: A random integer between 1 and num_sides, representing the die roll. + """ + import random + + return random.randint(1, num_sides) + + +USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ + MessageCreate( + role="user", + content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.", + ) +] + + +async def accumulate_chunks(chunks: Any) -> List[Any]: + """ + Accumulates chunks into a list of messages. + """ + messages = [] + current_message = None + prev_message_type = None + async for chunk in chunks: + current_message_type = chunk.message_type + if prev_message_type != current_message_type: + messages.append(current_message) + current_message = chunk + else: + current_message = chunk + prev_message_type = current_message_type + messages.append(current_message) + return [m for m in messages if m is not None] + + +async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5): + await asyncio.sleep(delay) + await client.agents.messages.cancel(agent_id=agent_id) + + +@pytest.fixture(scope="function") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it's accepting connections. + """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + timeout_seconds = 30 + import time + + import httpx + + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + response = httpx.get(url + "/v1/health", timeout=1.0) + if response.status_code == 200: + break + except Exception: + pass + time.sleep(0.5) + else: + raise TimeoutError(f"Server at {url} did not become ready in {timeout_seconds}s") + + return url + + +@pytest.fixture(scope="function") +async def client(server_url: str) -> AsyncLetta: + """ + Creates and returns an asynchronous Letta REST client for testing. + """ + client_instance = AsyncLetta(base_url=server_url) + yield client_instance + + +@pytest.fixture(scope="function") +async def agent_state(client: AsyncLetta) -> AgentState: + """ + Creates and returns an agent state for testing with a pre-configured agent. + The agent is configured with the roll_dice tool. + """ + dice_tool = await client.tools.upsert_from_function(func=roll_dice) + + agent_state_instance = await client.agents.create( + agent_type=AgentType.letta_v1_agent, + name="test_agent", + include_base_tools=False, + tool_ids=[dice_tool.id], + model="openai/gpt-4o", + embedding="openai/text-embedding-3-small", + tags=["test"], + ) + yield agent_state_instance + + await client.agents.delete(agent_state_instance.id) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +@pytest.mark.asyncio(loop_scope="function") +async def test_background_streaming_cancellation( + disable_e2b_api_key: Any, + client: AsyncLetta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + delay = 5 if llm_config.model == "gpt-5" else 0.5 + _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay)) + + response = client.agents.messages.create_stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_ROLL_DICE, + stream_tokens=True, + background=True, + ) + messages = await accumulate_chunks(response) + run_id = messages[0].run_id + + await _cancellation_task + + run = await client.runs.retrieve(run_id=run_id) + assert run.status == JobStatus.cancelled + + response = client.runs.stream(run_id=run_id, starting_after=0) + messages_from_stream = await accumulate_chunks(response) + assert len(messages_from_stream) > 0