feat: continue stream processing on client cancel (#3796)

This commit is contained in:
cthomas
2025-08-07 13:17:36 -07:00
committed by GitHub
parent a6ada1e4ea
commit db41f01ac2
3 changed files with 1536 additions and 1433 deletions

View File

@@ -126,271 +126,6 @@ class AnthropicStreamingInterface:
logger.error("Error checking inner thoughts: %s", e)
raise
async def process(
self,
stream: AsyncStream[BetaRawMessageStreamEvent],
ttft_span: Optional["Span"] = None,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
prev_message_type = None
message_index = 0
try:
async with stream:
async for event in stream:
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block
if isinstance(content, BetaTextBlock):
self.anthropic_mode = EventMode.TEXT
# TODO: Can capture citations, etc.
elif isinstance(content, BetaToolUseBlock):
self.anthropic_mode = EventMode.TOOL_USE
self.tool_call_id = content.id
self.tool_call_name = content.name
self.inner_thoughts_complete = False
if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
if prev_message_type and prev_message_type != "hidden_reasoning_message":
message_index += 1
hidden_reasoning_message = HiddenReasoningMessage(
id=self.letta_message_id,
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(hidden_reasoning_message)
prev_message_type = hidden_reasoning_message.message_type
yield hidden_reasoning_message
elif isinstance(event, BetaRawContentBlockDeltaEvent):
delta = event.delta
if isinstance(delta, BetaTextDelta):
# Safety check
if not self.anthropic_mode == EventMode.TEXT:
raise RuntimeError(
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
)
# Combine buffer with current text to handle tags split across chunks
combined_text = self.partial_tag_buffer + delta.text
# Remove all occurrences of </thinking> tag
cleaned_text = combined_text.replace("</thinking>", "")
# Extract just the new content (without the buffer part)
if len(self.partial_tag_buffer) <= len(cleaned_text):
delta.text = cleaned_text[len(self.partial_tag_buffer) :]
else:
# Edge case: the tag was removed and now the text is shorter than the buffer
delta.text = ""
# Store the last 10 characters (or all if less than 10) for the next chunk
# This is enough to catch "</thinking" which is 10 characters
self.partial_tag_buffer = combined_text[-10:] if len(combined_text) > 10 else combined_text
self.accumulated_inner_thoughts.append(delta.text)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(delta, BetaInputJSONDelta):
if not self.anthropic_mode == EventMode.TOOL_USE:
raise RuntimeError(
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args += delta.partial_json
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
reasoning=inner_thoughts_diff,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
if len(self.tool_call_buffer) > 0:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
# Strip out the inner thoughts from the buffered tool call arguments before streaming
tool_call_args = ""
for buffered_msg in self.tool_call_buffer:
tool_call_args += buffered_msg.tool_call.arguments if buffered_msg.tool_call.arguments else ""
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
tool_call_msg = ToolCallMessage(
id=self.tool_call_buffer[0].id,
otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index),
date=self.tool_call_buffer[0].date,
name=self.tool_call_buffer[0].name,
sender_id=self.tool_call_buffer[0].sender_id,
step_id=self.tool_call_buffer[0].step_id,
tool_call=ToolCallDelta(
name=self.tool_call_name,
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
self.tool_call_buffer = []
# Start detecting special case of "send_message"
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
send_message_diff = current_send_message[len(previous_send_message) :]
# Only stream out if it's not an empty string
if send_message_diff:
if prev_message_type and prev_message_type != "assistant_message":
message_index += 1
assistant_msg = AssistantMessage(
id=self.letta_message_id,
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = assistant_msg.message_type
yield assistant_msg
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(
name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json
),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
else:
self.tool_call_buffer.append(tool_call_msg)
# Set previous parse
self.previous_parse = current_parsed
elif isinstance(delta, BetaThinkingDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
self.anthropic_mode = None
except asyncio.CancelledError as e:
import traceback
logger.error("Cancelled stream %s: %s", e, traceback.format_exc())
ttft_span.add_event(
name="stop_reason",
attributes={"stop_reason": StopReasonType.cancelled.value, "error": str(e), "stacktrace": traceback.format_exc()},
)
raise e
except Exception as e:
import traceback
logger.error("Error processing stream: %s", e, traceback.format_exc())
ttft_span.add_event(
name="stop_reason",
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
)
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
def _process_group(
group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str
@@ -445,3 +180,294 @@ class AnthropicStreamingInterface:
content.text = content.text[:cutoff]
return merged
async def process(
self,
stream: AsyncStream[BetaRawMessageStreamEvent],
ttft_span: Optional["Span"] = None,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
prev_message_type = None
message_index = 0
event = None
try:
async with stream:
async for event in stream:
try:
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:
if prev_message_type != None:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:
if prev_message_type != None:
message_index += 1
prev_message_type = new_message_type
yield message
# Don't raise the exception here
continue
except Exception as e:
import traceback
logger.error("Error processing stream: %s", e, traceback.format_exc())
ttft_span.add_event(
name="stop_reason",
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
)
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
async def _process_event(
self,
event: BetaRawMessageStreamEvent,
ttft_span: Optional["Span"] = None,
prev_message_type: Optional[str] = None,
message_index: int = 0,
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
"""Process a single event from the Anthropic stream and yield any resulting messages.
Args:
event: The event to process
Yields:
Messages generated from processing this event
"""
if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block
if isinstance(content, BetaTextBlock):
self.anthropic_mode = EventMode.TEXT
# TODO: Can capture citations, etc.
elif isinstance(content, BetaToolUseBlock):
self.anthropic_mode = EventMode.TOOL_USE
self.tool_call_id = content.id
self.tool_call_name = content.name
self.inner_thoughts_complete = False
if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
if prev_message_type and prev_message_type != "hidden_reasoning_message":
message_index += 1
hidden_reasoning_message = HiddenReasoningMessage(
id=self.letta_message_id,
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(hidden_reasoning_message)
prev_message_type = hidden_reasoning_message.message_type
yield hidden_reasoning_message
elif isinstance(event, BetaRawContentBlockDeltaEvent):
delta = event.delta
if isinstance(delta, BetaTextDelta):
# Safety check
if not self.anthropic_mode == EventMode.TEXT:
raise RuntimeError(f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}")
# Combine buffer with current text to handle tags split across chunks
combined_text = self.partial_tag_buffer + delta.text
# Remove all occurrences of </thinking> tag
cleaned_text = combined_text.replace("</thinking>", "")
# Extract just the new content (without the buffer part)
if len(self.partial_tag_buffer) <= len(cleaned_text):
delta.text = cleaned_text[len(self.partial_tag_buffer) :]
else:
# Edge case: the tag was removed and now the text is shorter than the buffer
delta.text = ""
# Store the last 10 characters (or all if less than 10) for the next chunk
# This is enough to catch "</thinking" which is 10 characters
self.partial_tag_buffer = combined_text[-10:] if len(combined_text) > 10 else combined_text
self.accumulated_inner_thoughts.append(delta.text)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(delta, BetaInputJSONDelta):
if not self.anthropic_mode == EventMode.TOOL_USE:
raise RuntimeError(
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args += delta.partial_json
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
reasoning=inner_thoughts_diff,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
if len(self.tool_call_buffer) > 0:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
# Strip out the inner thoughts from the buffered tool call arguments before streaming
tool_call_args = ""
for buffered_msg in self.tool_call_buffer:
tool_call_args += buffered_msg.tool_call.arguments if buffered_msg.tool_call.arguments else ""
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
tool_call_msg = ToolCallMessage(
id=self.tool_call_buffer[0].id,
otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index),
date=self.tool_call_buffer[0].date,
name=self.tool_call_buffer[0].name,
sender_id=self.tool_call_buffer[0].sender_id,
step_id=self.tool_call_buffer[0].step_id,
tool_call=ToolCallDelta(
name=self.tool_call_name,
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
self.tool_call_buffer = []
# Start detecting special case of "send_message"
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
send_message_diff = current_send_message[len(previous_send_message) :]
# Only stream out if it's not an empty string
if send_message_diff:
if prev_message_type and prev_message_type != "assistant_message":
message_index += 1
assistant_msg = AssistantMessage(
id=self.letta_message_id,
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = assistant_msg.message_type
yield assistant_msg
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
else:
self.tool_call_buffer.append(tool_call_msg)
# Set previous parse
self.previous_parse = current_parsed
elif isinstance(delta, BetaThinkingDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
if prev_message_type and prev_message_type != "reasoning_message":
message_index += 1
reasoning_message = ReasoningMessage(
id=self.letta_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
yield reasoning_message
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
self.anthropic_mode = None