feat: continue stream processing on client cancel (#3796)
This commit is contained in:
@@ -126,271 +126,6 @@ class AnthropicStreamingInterface:
|
||||
logger.error("Error checking inner thoughts: %s", e)
|
||||
raise
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
try:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
|
||||
if isinstance(event, BetaRawContentBlockStartEvent):
|
||||
content = event.content_block
|
||||
|
||||
if isinstance(content, BetaTextBlock):
|
||||
self.anthropic_mode = EventMode.TEXT
|
||||
# TODO: Can capture citations, etc.
|
||||
elif isinstance(content, BetaToolUseBlock):
|
||||
self.anthropic_mode = EventMode.TOOL_USE
|
||||
self.tool_call_id = content.id
|
||||
self.tool_call_name = content.name
|
||||
self.inner_thoughts_complete = False
|
||||
|
||||
if not self.use_assistant_message:
|
||||
# Buffer the initial tool call message instead of yielding immediately
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
elif isinstance(content, BetaThinkingBlock):
|
||||
self.anthropic_mode = EventMode.THINKING
|
||||
# TODO: Can capture signature, etc.
|
||||
elif isinstance(content, BetaRedactedThinkingBlock):
|
||||
self.anthropic_mode = EventMode.REDACTED_THINKING
|
||||
if prev_message_type and prev_message_type != "hidden_reasoning_message":
|
||||
message_index += 1
|
||||
hidden_reasoning_message = HiddenReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
prev_message_type = hidden_reasoning_message.message_type
|
||||
yield hidden_reasoning_message
|
||||
|
||||
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
||||
delta = event.delta
|
||||
|
||||
if isinstance(delta, BetaTextDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.TEXT:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
|
||||
)
|
||||
|
||||
# Combine buffer with current text to handle tags split across chunks
|
||||
combined_text = self.partial_tag_buffer + delta.text
|
||||
|
||||
# Remove all occurrences of </thinking> tag
|
||||
cleaned_text = combined_text.replace("</thinking>", "")
|
||||
|
||||
# Extract just the new content (without the buffer part)
|
||||
if len(self.partial_tag_buffer) <= len(cleaned_text):
|
||||
delta.text = cleaned_text[len(self.partial_tag_buffer) :]
|
||||
else:
|
||||
# Edge case: the tag was removed and now the text is shorter than the buffer
|
||||
delta.text = ""
|
||||
|
||||
# Store the last 10 characters (or all if less than 10) for the next chunk
|
||||
# This is enough to catch "</thinking" which is 10 characters
|
||||
self.partial_tag_buffer = combined_text[-10:] if len(combined_text) > 10 else combined_text
|
||||
self.accumulated_inner_thoughts.append(delta.text)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
reasoning=self.accumulated_inner_thoughts[-1],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
elif isinstance(delta, BetaInputJSONDelta):
|
||||
if not self.anthropic_mode == EventMode.TOOL_USE:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
|
||||
)
|
||||
|
||||
self.accumulated_tool_call_args += delta.partial_json
|
||||
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
|
||||
|
||||
# Start detecting a difference in inner thoughts
|
||||
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
|
||||
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
|
||||
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
|
||||
|
||||
if inner_thoughts_diff:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
reasoning=inner_thoughts_diff,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
# Check if inner thoughts are complete - if so, flush the buffer
|
||||
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
|
||||
self.inner_thoughts_complete = True
|
||||
# Flush all buffered tool call messages
|
||||
if len(self.tool_call_buffer) > 0:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
|
||||
# Strip out the inner thoughts from the buffered tool call arguments before streaming
|
||||
tool_call_args = ""
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
tool_call_args += buffered_msg.tool_call.arguments if buffered_msg.tool_call.arguments else ""
|
||||
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
|
||||
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.tool_call_buffer[0].id,
|
||||
otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index),
|
||||
date=self.tool_call_buffer[0].date,
|
||||
name=self.tool_call_buffer[0].name,
|
||||
sender_id=self.tool_call_buffer[0].sender_id,
|
||||
step_id=self.tool_call_buffer[0].step_id,
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.tool_call_name,
|
||||
tool_call_id=self.tool_call_id,
|
||||
arguments=tool_call_args,
|
||||
),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
# Start detecting special case of "send_message"
|
||||
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
|
||||
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
send_message_diff = current_send_message[len(previous_send_message) :]
|
||||
|
||||
# Only stream out if it's not an empty string
|
||||
if send_message_diff:
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_msg = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
content=[TextContent(text=send_message_diff)],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_msg.message_type
|
||||
yield assistant_msg
|
||||
else:
|
||||
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json
|
||||
),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
if self.inner_thoughts_complete:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
else:
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
|
||||
# Set previous parse
|
||||
self.previous_parse = current_parsed
|
||||
elif isinstance(delta, BetaThinkingDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
elif isinstance(delta, BetaSignatureDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
signature=delta.signature,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
elif isinstance(event, BetaRawMessageStartEvent):
|
||||
self.message_id = event.message.id
|
||||
self.input_tokens += event.message.usage.input_tokens
|
||||
self.output_tokens += event.message.usage.output_tokens
|
||||
self.model = event.message.model
|
||||
elif isinstance(event, BetaRawMessageDeltaEvent):
|
||||
self.output_tokens += event.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageStopEvent):
|
||||
# Don't do anything here! We don't want to stop the stream.
|
||||
pass
|
||||
elif isinstance(event, BetaRawContentBlockStopEvent):
|
||||
# If we're exiting a tool use block and there are still buffered messages,
|
||||
# we should flush them now
|
||||
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
self.anthropic_mode = None
|
||||
except asyncio.CancelledError as e:
|
||||
import traceback
|
||||
|
||||
logger.error("Cancelled stream %s: %s", e, traceback.format_exc())
|
||||
ttft_span.add_event(
|
||||
name="stop_reason",
|
||||
attributes={"stop_reason": StopReasonType.cancelled.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error("Error processing stream: %s", e, traceback.format_exc())
|
||||
ttft_span.add_event(
|
||||
name="stop_reason",
|
||||
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]:
|
||||
def _process_group(
|
||||
group: list[ReasoningMessage | HiddenReasoningMessage], group_type: str
|
||||
@@ -445,3 +180,294 @@ class AnthropicStreamingInterface:
|
||||
content.text = content.text[:cutoff]
|
||||
|
||||
return merged
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
event = None
|
||||
try:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
try:
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
if prev_message_type != None:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
import traceback
|
||||
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
if prev_message_type != None:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
|
||||
# Don't raise the exception here
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error("Error processing stream: %s", e, traceback.format_exc())
|
||||
ttft_span.add_event(
|
||||
name="stop_reason",
|
||||
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
async def _process_event(
|
||||
self,
|
||||
event: BetaRawMessageStreamEvent,
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""Process a single event from the Anthropic stream and yield any resulting messages.
|
||||
|
||||
Args:
|
||||
event: The event to process
|
||||
|
||||
Yields:
|
||||
Messages generated from processing this event
|
||||
"""
|
||||
if isinstance(event, BetaRawContentBlockStartEvent):
|
||||
content = event.content_block
|
||||
|
||||
if isinstance(content, BetaTextBlock):
|
||||
self.anthropic_mode = EventMode.TEXT
|
||||
# TODO: Can capture citations, etc.
|
||||
elif isinstance(content, BetaToolUseBlock):
|
||||
self.anthropic_mode = EventMode.TOOL_USE
|
||||
self.tool_call_id = content.id
|
||||
self.tool_call_name = content.name
|
||||
self.inner_thoughts_complete = False
|
||||
|
||||
if not self.use_assistant_message:
|
||||
# Buffer the initial tool call message instead of yielding immediately
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
elif isinstance(content, BetaThinkingBlock):
|
||||
self.anthropic_mode = EventMode.THINKING
|
||||
# TODO: Can capture signature, etc.
|
||||
elif isinstance(content, BetaRedactedThinkingBlock):
|
||||
self.anthropic_mode = EventMode.REDACTED_THINKING
|
||||
if prev_message_type and prev_message_type != "hidden_reasoning_message":
|
||||
message_index += 1
|
||||
hidden_reasoning_message = HiddenReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
prev_message_type = hidden_reasoning_message.message_type
|
||||
yield hidden_reasoning_message
|
||||
|
||||
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
||||
delta = event.delta
|
||||
|
||||
if isinstance(delta, BetaTextDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.TEXT:
|
||||
raise RuntimeError(f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}")
|
||||
|
||||
# Combine buffer with current text to handle tags split across chunks
|
||||
combined_text = self.partial_tag_buffer + delta.text
|
||||
|
||||
# Remove all occurrences of </thinking> tag
|
||||
cleaned_text = combined_text.replace("</thinking>", "")
|
||||
|
||||
# Extract just the new content (without the buffer part)
|
||||
if len(self.partial_tag_buffer) <= len(cleaned_text):
|
||||
delta.text = cleaned_text[len(self.partial_tag_buffer) :]
|
||||
else:
|
||||
# Edge case: the tag was removed and now the text is shorter than the buffer
|
||||
delta.text = ""
|
||||
|
||||
# Store the last 10 characters (or all if less than 10) for the next chunk
|
||||
# This is enough to catch "</thinking" which is 10 characters
|
||||
self.partial_tag_buffer = combined_text[-10:] if len(combined_text) > 10 else combined_text
|
||||
self.accumulated_inner_thoughts.append(delta.text)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
reasoning=self.accumulated_inner_thoughts[-1],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
elif isinstance(delta, BetaInputJSONDelta):
|
||||
if not self.anthropic_mode == EventMode.TOOL_USE:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
|
||||
)
|
||||
|
||||
self.accumulated_tool_call_args += delta.partial_json
|
||||
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
|
||||
|
||||
# Start detecting a difference in inner thoughts
|
||||
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
|
||||
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
|
||||
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
|
||||
|
||||
if inner_thoughts_diff:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
reasoning=inner_thoughts_diff,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
# Check if inner thoughts are complete - if so, flush the buffer
|
||||
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
|
||||
self.inner_thoughts_complete = True
|
||||
# Flush all buffered tool call messages
|
||||
if len(self.tool_call_buffer) > 0:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
|
||||
# Strip out the inner thoughts from the buffered tool call arguments before streaming
|
||||
tool_call_args = ""
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
tool_call_args += buffered_msg.tool_call.arguments if buffered_msg.tool_call.arguments else ""
|
||||
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
|
||||
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.tool_call_buffer[0].id,
|
||||
otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index),
|
||||
date=self.tool_call_buffer[0].date,
|
||||
name=self.tool_call_buffer[0].name,
|
||||
sender_id=self.tool_call_buffer[0].sender_id,
|
||||
step_id=self.tool_call_buffer[0].step_id,
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.tool_call_name,
|
||||
tool_call_id=self.tool_call_id,
|
||||
arguments=tool_call_args,
|
||||
),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
# Start detecting special case of "send_message"
|
||||
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
|
||||
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
send_message_diff = current_send_message[len(previous_send_message) :]
|
||||
|
||||
# Only stream out if it's not an empty string
|
||||
if send_message_diff:
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_msg = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
content=[TextContent(text=send_message_diff)],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_msg.message_type
|
||||
yield assistant_msg
|
||||
else:
|
||||
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
if self.inner_thoughts_complete:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
else:
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
|
||||
# Set previous parse
|
||||
self.previous_parse = current_parsed
|
||||
elif isinstance(delta, BetaThinkingDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
elif isinstance(delta, BetaSignatureDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
signature=delta.signature,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
elif isinstance(event, BetaRawMessageStartEvent):
|
||||
self.message_id = event.message.id
|
||||
self.input_tokens += event.message.usage.input_tokens
|
||||
self.output_tokens += event.message.usage.output_tokens
|
||||
self.model = event.message.model
|
||||
elif isinstance(event, BetaRawMessageDeltaEvent):
|
||||
self.output_tokens += event.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageStopEvent):
|
||||
# Don't do anything here! We don't want to stop the stream.
|
||||
pass
|
||||
elif isinstance(event, BetaRawContentBlockStopEvent):
|
||||
# If we're exiting a tool use block and there are still buffered messages,
|
||||
# we should flush them now
|
||||
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
self.anthropic_mode = None
|
||||
|
||||
@@ -120,260 +120,34 @@ class OpenAIStreamingInterface:
|
||||
tool_dicts = [tool["function"] if isinstance(tool, dict) and "function" in tool else tool for tool in self.tools]
|
||||
self.fallback_input_tokens += num_tokens_from_functions(tool_dicts)
|
||||
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
try:
|
||||
async with stream:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
async for chunk in stream:
|
||||
if not self.model or not self.message_id:
|
||||
self.model = chunk.model
|
||||
self.message_id = chunk.id
|
||||
try:
|
||||
async for message in self._process_chunk(chunk, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
if prev_message_type != None:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
import traceback
|
||||
|
||||
# track usage
|
||||
if chunk.usage:
|
||||
self.input_tokens += chunk.usage.prompt_tokens
|
||||
self.output_tokens += chunk.usage.completion_tokens
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
async for message in self._process_chunk(chunk, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
if prev_message_type != None:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
# Don't raise the exception here
|
||||
continue
|
||||
|
||||
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
|
||||
tool_call = message_delta.tool_calls[0]
|
||||
|
||||
if tool_call.function.name:
|
||||
# If we're waiting for the first key, then we should hold back the name
|
||||
# ie add it to a buffer instead of returning it as a chunk
|
||||
if self.function_name_buffer is None:
|
||||
self.function_name_buffer = tool_call.function.name
|
||||
else:
|
||||
self.function_name_buffer += tool_call.function.name
|
||||
|
||||
if tool_call.id:
|
||||
# Buffer until next time
|
||||
if self.function_id_buffer is None:
|
||||
self.function_id_buffer = tool_call.id
|
||||
else:
|
||||
self.function_id_buffer += tool_call.id
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
if self.is_openai_proxy:
|
||||
self.fallback_output_tokens += count_tokens(tool_call.function.arguments)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
if updates_inner_thoughts:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
self.reasoning_messages.append(updates_inner_thoughts)
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
reasoning=updates_inner_thoughts,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
if updates_main_json:
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If we have main_json, we should output a ToolCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
# safe we'll output name first in a separate chunk
|
||||
if self.function_name_buffer:
|
||||
|
||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
self.tool_call_name = str(self.function_name_buffer)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
self.last_flushed_function_name = self.function_name_buffer
|
||||
if self.last_flushed_function_id is None:
|
||||
self.last_flushed_function_id = self.function_id_buffer
|
||||
# Clear the buffer
|
||||
self.function_name_buffer = None
|
||||
self.function_id_buffer = None
|
||||
# Since we're clearing the name buffer, we should store
|
||||
# any updates to the arguments inside a separate buffer
|
||||
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
else:
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and (
|
||||
self.last_flushed_function_name is not None
|
||||
and self.last_flushed_function_name == self.assistant_message_tool_name
|
||||
):
|
||||
# do an additional parse on the updates_main_json
|
||||
if self.function_args_buffer:
|
||||
updates_main_json = self.function_args_buffer + updates_main_json
|
||||
self.function_args_buffer = None
|
||||
|
||||
# Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
|
||||
match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
|
||||
if updates_main_json == match_str:
|
||||
updates_main_json = None
|
||||
|
||||
else:
|
||||
# Some hardcoding to strip off the trailing "}"
|
||||
if updates_main_json in ["}", '"}']:
|
||||
updates_main_json = None
|
||||
if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
|
||||
updates_main_json = updates_main_json[:-1]
|
||||
|
||||
if not updates_main_json:
|
||||
# early exit to turn into content mode
|
||||
continue
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_message = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
content=combined_chunk,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
# TODO: THIS IS HORRIBLE
|
||||
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
|
||||
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
|
||||
new_content = parsed_args.get(self.assistant_message_tool_kwarg)
|
||||
prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_message = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
content=diff,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
# clear buffers
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.function_id_buffer = None
|
||||
except asyncio.CancelledError as e:
|
||||
import traceback
|
||||
|
||||
logger.error("Cancelled stream %s: %s", e, traceback.format_exc())
|
||||
ttft_span.add_event(
|
||||
name="stop_reason",
|
||||
attributes={"stop_reason": StopReasonType.cancelled.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
@@ -386,3 +160,249 @@ class OpenAIStreamingInterface:
|
||||
raise e
|
||||
finally:
|
||||
logger.info("OpenAIStreamingInterface: Stream processing complete.")
|
||||
|
||||
async def _process_chunk(
|
||||
self,
|
||||
chunk: ChatCompletionChunk,
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
if not self.model or not self.message_id:
|
||||
self.model = chunk.model
|
||||
self.message_id = chunk.id
|
||||
|
||||
# track usage
|
||||
if chunk.usage:
|
||||
self.input_tokens += chunk.usage.prompt_tokens
|
||||
self.output_tokens += chunk.usage.completion_tokens
|
||||
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
|
||||
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
|
||||
tool_call = message_delta.tool_calls[0]
|
||||
|
||||
if tool_call.function.name:
|
||||
# If we're waiting for the first key, then we should hold back the name
|
||||
# ie add it to a buffer instead of returning it as a chunk
|
||||
if self.function_name_buffer is None:
|
||||
self.function_name_buffer = tool_call.function.name
|
||||
else:
|
||||
self.function_name_buffer += tool_call.function.name
|
||||
|
||||
if tool_call.id:
|
||||
# Buffer until next time
|
||||
if self.function_id_buffer is None:
|
||||
self.function_id_buffer = tool_call.id
|
||||
else:
|
||||
self.function_id_buffer += tool_call.id
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
if self.is_openai_proxy:
|
||||
self.fallback_output_tokens += count_tokens(tool_call.function.arguments)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
if updates_inner_thoughts:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
self.reasoning_messages.append(updates_inner_thoughts)
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
reasoning=updates_inner_thoughts,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
if updates_main_json:
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If we have main_json, we should output a ToolCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
# safe we'll output name first in a separate chunk
|
||||
if self.function_name_buffer:
|
||||
|
||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
self.tool_call_name = str(self.function_name_buffer)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
self.last_flushed_function_name = self.function_name_buffer
|
||||
if self.last_flushed_function_id is None:
|
||||
self.last_flushed_function_id = self.function_id_buffer
|
||||
# Clear the buffer
|
||||
self.function_name_buffer = None
|
||||
self.function_id_buffer = None
|
||||
# Since we're clearing the name buffer, we should store
|
||||
# any updates to the arguments inside a separate buffer
|
||||
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
else:
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and (
|
||||
self.last_flushed_function_name is not None
|
||||
and self.last_flushed_function_name == self.assistant_message_tool_name
|
||||
):
|
||||
# do an additional parse on the updates_main_json
|
||||
if self.function_args_buffer:
|
||||
updates_main_json = self.function_args_buffer + updates_main_json
|
||||
self.function_args_buffer = None
|
||||
|
||||
# Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
|
||||
match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
|
||||
if updates_main_json == match_str:
|
||||
updates_main_json = None
|
||||
|
||||
else:
|
||||
# Some hardcoding to strip off the trailing "}"
|
||||
if updates_main_json in ["}", '"}']:
|
||||
updates_main_json = None
|
||||
if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
|
||||
updates_main_json = updates_main_json[:-1]
|
||||
|
||||
if not updates_main_json:
|
||||
# early exit to turn into content mode
|
||||
pass
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_message = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
content=combined_chunk,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
# TODO: THIS IS HORRIBLE
|
||||
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
|
||||
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
|
||||
new_content = parsed_args.get(self.assistant_message_tool_kwarg)
|
||||
prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_message = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
content=diff,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
# clear buffers
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.function_id_buffer = None
|
||||
|
||||
1897
poetry.lock
generated
1897
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user