feat: support for agent loop job cancelation (#2837)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
@@ -131,14 +133,16 @@ class AnthropicStreamingInterface:
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
provider_request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> AsyncGenerator[LettaMessage, None]:
|
||||
provider_request_start_timestamp_ns: int | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
first_chunk = True
|
||||
try:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
# TODO (cliandy): reconsider in stream cancellations
|
||||
# await cancellation_token.check_and_raise_if_cancelled()
|
||||
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - provider_request_start_timestamp_ns
|
||||
@@ -384,18 +388,21 @@ class AnthropicStreamingInterface:
|
||||
self.tool_call_buffer = []
|
||||
|
||||
self.anthropic_mode = None
|
||||
except asyncio.CancelledError as e:
|
||||
logger.info("Cancelled stream %s", e)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.cancelled)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error processing stream: %s", e)
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
yield stop_reason
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
|
||||
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
|
||||
def _process_group(
|
||||
group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
|
||||
) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
|
||||
group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str
|
||||
) -> TextContent | ReasoningContent | RedactedReasoningContent:
|
||||
if group_type == "reasoning":
|
||||
reasoning_text = "".join(chunk.reasoning for chunk in group).strip()
|
||||
is_native = any(chunk.source == "reasoner_model" for chunk in group)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
@@ -19,14 +20,14 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.stream_pre_execution_message: bool = stream_pre_execution_message
|
||||
|
||||
self.current_parsed_json_result: Dict[str, Any] = {}
|
||||
self.content_buffer: List[str] = []
|
||||
self.current_parsed_json_result: dict[str, Any] = {}
|
||||
self.content_buffer: list[str] = []
|
||||
self.tool_call_happened: bool = False
|
||||
self.finish_reason_stop: bool = False
|
||||
|
||||
self.tool_call_name: Optional[str] = None
|
||||
self.tool_call_name: str | None = None
|
||||
self.tool_call_args_str: str = ""
|
||||
self.tool_call_id: Optional[str] = None
|
||||
self.tool_call_id: str | None = None
|
||||
|
||||
async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
@@ -35,6 +36,8 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
"""
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
# TODO (cliandy): reconsider in stream cancellations
|
||||
# await cancellation_token.check_and_raise_if_cancelled()
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
@@ -103,7 +106,7 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool:
|
||||
def _handle_finish_reason(self, finish_reason: str | None) -> bool:
|
||||
"""Handles the finish reason and determines if streaming should stop."""
|
||||
if finish_reason == "tool_calls":
|
||||
self.tool_call_happened = True
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
@@ -55,12 +57,12 @@ class OpenAIStreamingInterface:
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
|
||||
self.content_buffer: List[str] = []
|
||||
self.tool_call_name: Optional[str] = None
|
||||
self.tool_call_id: Optional[str] = None
|
||||
self.content_buffer: list[str] = []
|
||||
self.tool_call_name: str | None = None
|
||||
self.tool_call_id: str | None = None
|
||||
self.reasoning_messages = []
|
||||
|
||||
def get_reasoning_content(self) -> List[TextContent]:
|
||||
def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
|
||||
content = "".join(self.reasoning_messages).strip()
|
||||
|
||||
# Right now we assume that all models omit reasoning content for OAI,
|
||||
@@ -87,8 +89,8 @@ class OpenAIStreamingInterface:
|
||||
self,
|
||||
stream: AsyncStream[ChatCompletionChunk],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
provider_request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> AsyncGenerator[LettaMessage, None]:
|
||||
provider_request_start_timestamp_ns: int | None = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the OpenAI stream, yielding SSE events.
|
||||
It also collects tokens and detects if a tool call is triggered.
|
||||
@@ -99,6 +101,8 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
async for chunk in stream:
|
||||
# TODO (cliandy): reconsider in stream cancellations
|
||||
# await cancellation_token.check_and_raise_if_cancelled()
|
||||
if first_chunk and ttft_span is not None and provider_request_start_timestamp_ns is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - provider_request_start_timestamp_ns
|
||||
@@ -224,8 +228,7 @@ class OpenAIStreamingInterface:
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
else:
|
||||
|
||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and (
|
||||
self.last_flushed_function_name is not None
|
||||
and self.last_flushed_function_name == self.assistant_message_tool_name
|
||||
@@ -349,10 +352,13 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.function_id_buffer = None
|
||||
except asyncio.CancelledError as e:
|
||||
logger.info("Cancelled stream %s", e)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.cancelled)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error processing stream: %s", e)
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
yield stop_reason
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise
|
||||
finally:
|
||||
logger.info("OpenAIStreamingInterface: Stream processing complete.")
|
||||
|
||||
Reference in New Issue
Block a user