feat: add special approval request otid for openai streaming (#5744)
* feat: add special approval request otid for openai streaming * fix import
This commit is contained in:
@@ -52,6 +52,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.utils import decrement_message_uuid
|
||||
from letta.streaming_utils import (
|
||||
FunctionArgumentsStreamHandler,
|
||||
JSONInnerThoughtsExtractor,
|
||||
@@ -325,14 +326,14 @@ class OpenAIStreamingInterface:
|
||||
self.tool_call_name = str(self.function_name_buffer)
|
||||
if self.tool_call_name in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
@@ -413,7 +414,7 @@ class OpenAIStreamingInterface:
|
||||
message_index += 1
|
||||
if self.function_name_buffer in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
@@ -421,7 +422,7 @@ class OpenAIStreamingInterface:
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
@@ -452,7 +453,7 @@ class OpenAIStreamingInterface:
|
||||
message_index += 1
|
||||
if self.function_name_buffer in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
@@ -460,7 +461,7 @@ class OpenAIStreamingInterface:
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
@@ -603,6 +604,8 @@ class SimpleOpenAIStreamingInterface:
|
||||
# For reasoning models, emit a hidden reasoning message before the first chunk
|
||||
if not self.emitted_hidden_reasoning and is_openai_reasoning_model(self.model):
|
||||
self.emitted_hidden_reasoning = True
|
||||
if prev_message_type and prev_message_type != "hidden_reasoning_message":
|
||||
message_index += 1
|
||||
hidden_message = HiddenReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
@@ -614,7 +617,6 @@ class SimpleOpenAIStreamingInterface:
|
||||
)
|
||||
self.content_messages.append(hidden_message)
|
||||
prev_message_type = hidden_message.message_type
|
||||
message_index += 1 # Increment for the next message
|
||||
yield hidden_message
|
||||
|
||||
async for chunk in stream:
|
||||
@@ -676,6 +678,8 @@ class SimpleOpenAIStreamingInterface:
|
||||
message_delta = choice.delta
|
||||
|
||||
if message_delta.content is not None and message_delta.content != "":
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_msg = AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
content=message_delta.content,
|
||||
@@ -686,7 +690,6 @@ class SimpleOpenAIStreamingInterface:
|
||||
)
|
||||
self.content_messages.append(assistant_msg)
|
||||
prev_message_type = assistant_msg.message_type
|
||||
message_index += 1
|
||||
yield assistant_msg
|
||||
|
||||
if (
|
||||
@@ -698,6 +701,8 @@ class SimpleOpenAIStreamingInterface:
|
||||
delta = chunk.choices[0].delta
|
||||
reasoning_content = getattr(delta, "reasoning_content", None)
|
||||
if reasoning_content is not None and reasoning_content != "":
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_msg = ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
@@ -710,7 +715,6 @@ class SimpleOpenAIStreamingInterface:
|
||||
)
|
||||
self.content_messages.append(reasoning_msg)
|
||||
prev_message_type = reasoning_msg.message_type
|
||||
message_index += 1
|
||||
yield reasoning_msg
|
||||
|
||||
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
|
||||
@@ -746,7 +750,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
|
||||
if self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=tool_call.function.name,
|
||||
@@ -754,11 +758,13 @@ class SimpleOpenAIStreamingInterface:
|
||||
tool_call_id=tool_call.id,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
@@ -774,8 +780,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
message_index += 1 # Increment for the next message
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
|
||||
@@ -971,11 +976,9 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# cache for approval if/elses
|
||||
self.tool_call_name = name
|
||||
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
||||
if prev_message_type and prev_message_type != "approval_request_message":
|
||||
message_index += 1
|
||||
yield ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=name,
|
||||
@@ -985,7 +988,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "tool_call_message"
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
@@ -1141,11 +1143,9 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
delta = event.delta
|
||||
|
||||
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
||||
if prev_message_type and prev_message_type != "approval_request_message":
|
||||
message_index += 1
|
||||
yield ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
@@ -1155,7 +1155,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = "approval_request_message"
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
|
||||
Reference in New Issue
Block a user