fix: approval request for streaming (#4445)

* fix: approval request for streaming

* fix: claude code attempt, unit test passing (add on to #4445) (#4448)

* fix: claude code attempt, unit test passing

* chore: update locks to 0.1.314 from 0.1.312

* chore: just stage-api && just publish-api

* chore: drop dead poetry lock

---------

Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
cthomas
2025-09-05 17:43:21 -07:00
committed by GitHub
parent a677095f05
commit cb7296c81d
8 changed files with 191 additions and 88 deletions

View File

@@ -11,6 +11,7 @@ from letta.llm_api.openai_client import is_openai_reasoning_model
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.schemas.letta_message import (
ApprovalRequestMessage,
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
@@ -43,6 +44,7 @@ class OpenAIStreamingInterface:
messages: Optional[list] = None,
tools: Optional[list] = None,
put_inner_thoughts_in_kwarg: bool = True,
requires_approval_tools: list = [],
):
self.use_assistant_message = use_assistant_message
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
@@ -86,6 +88,8 @@ class OpenAIStreamingInterface:
self.reasoning_messages = []
self.emitted_hidden_reasoning = False # Track if we've emitted hidden reasoning message
self.requires_approval_tools = requires_approval_tools
def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
content = "".join(self.reasoning_messages).strip()
@@ -274,16 +278,28 @@ class OpenAIStreamingInterface:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
self.tool_call_name = str(self.function_name_buffer)
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
tool_call_id=self.function_id_buffer,
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
if self.tool_call_name in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
tool_call_id=self.function_id_buffer,
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
tool_call_id=self.function_id_buffer,
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -404,17 +420,30 @@ class OpenAIStreamingInterface:
combined_chunk = self.function_args_buffer + updates_main_json
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
if self.function_name_buffer in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
# clear buffer
@@ -424,17 +453,30 @@ class OpenAIStreamingInterface:
# If there's no buffer to clear, just output a new chunk with new data
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
if self.function_name_buffer in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
self.function_id_buffer = None