feat: support for agent loop job cancelation (#2837)

Andy Li
2025-07-02 14:31:16 -07:00
committed by GitHub
parent 243d3d040b
commit 33c1f26ab6
17 changed files with 940 additions and 281 deletions

View File

@@ -1,7 +1,9 @@
 import asyncio
 import json
+from collections.abc import AsyncGenerator
 from datetime import datetime, timezone
 from enum import Enum
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Optional
+
 from anthropic import AsyncStream
 from anthropic.types.beta import (
@@ -131,14 +133,16 @@ class AnthropicStreamingInterface:
         self,
         stream: AsyncStream[BetaRawMessageStreamEvent],
         ttft_span: Optional["Span"] = None,
-        provider_request_start_timestamp_ns: Optional[int] = None,
-    ) -> AsyncGenerator[LettaMessage, None]:
+        provider_request_start_timestamp_ns: int | None = None,
+    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
         prev_message_type = None
         message_index = 0
         first_chunk = True
         try:
             async with stream:
                 async for event in stream:
+                    # TODO (cliandy): reconsider in stream cancellations
+                    # await cancellation_token.check_and_raise_if_cancelled()
                     if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
                         now = get_utc_timestamp_ns()
                         ttft_ns = now - provider_request_start_timestamp_ns
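
The first-chunk branch above is the time-to-first-token (TTFT) measurement: a nanosecond timestamp captured just before the provider request is subtracted from the arrival time of the first streamed event and recorded on the tracing span. A minimal sketch of the arithmetic, assuming letta's get_utc_timestamp_ns helper wraps time.time_ns():

import time

def get_utc_timestamp_ns() -> int:
    # assumption: letta's helper returns a wall-clock timestamp in nanoseconds
    return time.time_ns()

provider_request_start_timestamp_ns = get_utc_timestamp_ns()  # taken before the request
# ... first streamed event arrives ...
ttft_ns = get_utc_timestamp_ns() - provider_request_start_timestamp_ns
print(f"TTFT: {ttft_ns / 1e6:.1f} ms")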
@@ -384,18 +388,21 @@ class AnthropicStreamingInterface:
                     self.tool_call_buffer = []
                     self.anthropic_mode = None
+        except asyncio.CancelledError as e:
+            logger.info("Cancelled stream %s", e)
+            yield LettaStopReason(stop_reason=StopReasonType.cancelled)
+            raise
         except Exception as e:
             logger.error("Error processing stream: %s", e)
-            stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
-            yield stop_reason
+            yield LettaStopReason(stop_reason=StopReasonType.error)
             raise
         finally:
             logger.info("AnthropicStreamingInterface: Stream processing complete.")
 
-    def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
+    def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
         def _process_group(
-            group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
-        ) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
+            group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str
+        ) -> TextContent | ReasoningContent | RedactedReasoningContent:
             if group_type == "reasoning":
                 reasoning_text = "".join(chunk.reasoning for chunk in group).strip()
                 is_native = any(chunk.source == "reasoner_model" for chunk in group)
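
The new except asyncio.CancelledError branch is what makes job cancellation visible to stream consumers: the generator yields a final LettaStopReason(stop_reason=StopReasonType.cancelled) before re-raising, so the client receives an explicit stop reason instead of a silently truncated stream. A self-contained sketch of the same yield-then-reraise pattern, with stand-in types in place of Letta's schemas:

import asyncio
from enum import Enum

class StopReasonType(str, Enum):  # stand-in for letta's StopReasonType
    cancelled = "cancelled"
    error = "error"

async def stream_events():
    """Same shape as the interface's process(): yield chunks; on
    cancellation, surface a stop reason to the consumer, then re-raise."""
    try:
        for i in range(100):
            await asyncio.sleep(0.05)
            yield f"chunk-{i}"
    except asyncio.CancelledError:
        yield StopReasonType.cancelled  # consumer sees an explicit stop reason
        raise  # propagate so the task still ends up cancelled

async def consume():
    async for item in stream_events():
        print("received:", item)

async def main():
    task = asyncio.create_task(consume())
    await asyncio.sleep(0.12)  # let a couple of chunks through
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("stream task cancelled")

asyncio.run(main())

Cancelling the consuming task throws CancelledError into the generator at its await point; the except clause gets one last yield out to the consumer, and the bare raise then lets the cancellation propagate so the task still finishes in the cancelled state.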

View File

@@ -1,4 +1,5 @@
-from typing import Any, AsyncGenerator, Dict, List, Optional
+from collections.abc import AsyncGenerator
+from typing import Any
 
 from openai import AsyncStream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
@@ -19,14 +20,14 @@ class OpenAIChatCompletionsStreamingInterface:
         self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
         self.stream_pre_execution_message: bool = stream_pre_execution_message
-        self.current_parsed_json_result: Dict[str, Any] = {}
-        self.content_buffer: List[str] = []
+        self.current_parsed_json_result: dict[str, Any] = {}
+        self.content_buffer: list[str] = []
         self.tool_call_happened: bool = False
         self.finish_reason_stop: bool = False
-        self.tool_call_name: Optional[str] = None
+        self.tool_call_name: str | None = None
         self.tool_call_args_str: str = ""
-        self.tool_call_id: Optional[str] = None
+        self.tool_call_id: str | None = None
 
     async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]:
         """
@@ -35,6 +36,8 @@ class OpenAIChatCompletionsStreamingInterface:
"""
async with stream:
async for chunk in stream:
# TODO (cliandy): reconsider in stream cancellations
# await cancellation_token.check_and_raise_if_cancelled()
if chunk.choices:
choice = chunk.choices[0]
delta = choice.delta
@@ -103,7 +106,7 @@ class OpenAIChatCompletionsStreamingInterface:
                 )
             )
 
-    def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool:
+    def _handle_finish_reason(self, finish_reason: str | None) -> bool:
         """Handles the finish reason and determines if streaming should stop."""
         if finish_reason == "tool_calls":
             self.tool_call_happened = True
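
The TODO comments added to all three interfaces point at a finer-grained design: checking a cancellation token once per streamed chunk instead of relying on task cancellation. The commit leaves the hook commented out; one plausible shape for such a token (hypothetical, not part of this commit) would be:

import asyncio

class CancellationToken:
    """Hypothetical token matching the commented-out hook; not in this commit."""

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def cancel(self) -> None:
        self._event.set()

    async def check_and_raise_if_cancelled(self) -> None:
        # Called once per streamed chunk: cheap when not cancelled,
        # raises CancelledError as soon as a cancel is requested.
        if self._event.is_set():
            raise asyncio.CancelledError("job cancelled")

A per-chunk check bounds cancellation latency to a single chunk but costs a call on every event, which is presumably why the hook is left disabled pending reconsideration.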

View File

@@ -1,5 +1,7 @@
 import asyncio
+from collections.abc import AsyncGenerator
 from datetime import datetime, timezone
-from typing import AsyncGenerator, List, Optional
+from typing import Optional
+
 from openai import AsyncStream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -55,12 +57,12 @@ class OpenAIStreamingInterface:
         self.input_tokens = 0
         self.output_tokens = 0
 
-        self.content_buffer: List[str] = []
-        self.tool_call_name: Optional[str] = None
-        self.tool_call_id: Optional[str] = None
+        self.content_buffer: list[str] = []
+        self.tool_call_name: str | None = None
+        self.tool_call_id: str | None = None
 
         self.reasoning_messages = []
 
-    def get_reasoning_content(self) -> List[TextContent]:
+    def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
         content = "".join(self.reasoning_messages).strip()
         # Right now we assume that all models omit reasoning content for OAI,
@@ -87,8 +89,8 @@ class OpenAIStreamingInterface:
         self,
         stream: AsyncStream[ChatCompletionChunk],
         ttft_span: Optional["Span"] = None,
-        provider_request_start_timestamp_ns: Optional[int] = None,
-    ) -> AsyncGenerator[LettaMessage, None]:
+        provider_request_start_timestamp_ns: int | None = None,
+    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
         """
         Iterates over the OpenAI stream, yielding SSE events.
         It also collects tokens and detects if a tool call is triggered.
@@ -99,6 +101,8 @@ class OpenAIStreamingInterface:
             prev_message_type = None
             message_index = 0
             async for chunk in stream:
+                # TODO (cliandy): reconsider in stream cancellations
+                # await cancellation_token.check_and_raise_if_cancelled()
                 if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
                     now = get_utc_timestamp_ns()
                     ttft_ns = now - provider_request_start_timestamp_ns
@@ -224,8 +228,7 @@ class OpenAIStreamingInterface:
                     # If there was nothing in the name buffer, we can proceed to
                     # output the arguments chunk as a ToolCallMessage
                     else:
-                        # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
-
+                        # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
                         if self.use_assistant_message and (
                             self.last_flushed_function_name is not None
                             and self.last_flushed_function_name == self.assistant_message_tool_name
@@ -349,10 +352,13 @@ class OpenAIStreamingInterface:
                                 prev_message_type = tool_call_msg.message_type
                             yield tool_call_msg
                             self.function_id_buffer = None
+        except asyncio.CancelledError as e:
+            logger.info("Cancelled stream %s", e)
+            yield LettaStopReason(stop_reason=StopReasonType.cancelled)
+            raise
         except Exception as e:
             logger.error("Error processing stream: %s", e)
-            stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
-            yield stop_reason
+            yield LettaStopReason(stop_reason=StopReasonType.error)
             raise
         finally:
             logger.info("OpenAIStreamingInterface: Stream processing complete.")