fix: approval request for streaming (#4445)

* fix: approval request for streaming

* fix: claude code attempt, unit test passing (add on to #4445) (#4448)

* fix: claude code attempt, unit test passing

* chore: update locks to 0.1.314 from 0.1.312

* chore: just stage-api && just publish-api

* chore: drop dead poetry lock

---------

Co-authored-by: Charles Packer <packercharles@gmail.com>
Author: cthomas
Date: 2025-09-05 17:43:21 -07:00
Committed by: GitHub
Parent: a677095f05
Commit: cb7296c81d
8 changed files with 191 additions and 88 deletions

View File

@@ -1022,6 +1022,7 @@ class LettaAgent(BaseAgent):
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
)
elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
interface = OpenAIStreamingInterface(
@@ -1030,6 +1031,7 @@ class LettaAgent(BaseAgent):
messages=current_in_context_messages + new_in_context_messages,
tools=request_data.get("tools", []),
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
)
else:
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
@@ -1174,12 +1176,13 @@ class LettaAgent(BaseAgent):
)
step_progression = StepProgression.LOGGED_TRACE
# yield the tool response here; it is produced by Letta, not returned by the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
if not (use_assistant_message and tool_return.name == "send_message"):
# Apply message type filtering if specified
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
if persisted_messages[-1].role != "approval":
# yield the tool response here; it is produced by Letta, not returned by the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
if not (use_assistant_message and tool_return.name == "send_message"):
# Apply message type filtering if specified
if include_return_message_types is None or tool_return.message_type in include_return_message_types:
yield f"data: {tool_return.model_dump_json()}\n\n"
# TODO (cliandy): consolidate and expand with trace
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
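The new guard exists because an approval-required tool is not executed during this step, so there is no tool return to stream yet; a rough illustration, with plain dicts standing in for persisted Message objects:

# Rough illustration of the guard above; dicts stand in for Message objects.
def maybe_stream_tool_return(persisted_messages: list[dict]) -> str | None:
    if persisted_messages[-1]["role"] == "approval":
        # The step paused on an approval request: the tool never ran,
        # so there is no tool return to emit.
        return None
    tool_return = [m for m in persisted_messages if m["role"] == "tool"][-1]
    return f"data: {tool_return}\n\n"

assert maybe_stream_tool_return([{"role": "approval"}]) is None
assert maybe_stream_tool_return([{"role": "tool"}]) is not None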
@@ -1692,7 +1695,6 @@ class LettaAgent(BaseAgent):
tool_call_id=tool_call_id,
request_heartbeat=request_heartbeat,
)
if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
approval_message = create_approval_request_message_from_llm_response(
agent_id=agent_state.id,

View File

@@ -131,6 +131,10 @@ class ToolRulesSolver(BaseModel):
"""Check if all required-before-exit tools have been called."""
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
def get_requires_approval_tools(self, available_tools: set[ToolName]) -> list[ToolName]:
"""Get the list of tools that require approval."""
return [rule.tool_name for rule in self.requires_approval_tool_rules]
def get_uncalled_required_tools(self, available_tools: set[ToolName]) -> list[str]:
"""Get the list of required-before-exit tools that have not been called yet."""
if not self.required_before_exit_tool_rules:
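Worth noting: the new get_requires_approval_tools accepts available_tools but never reads it, so it returns every approval-rule tool name regardless of availability. The call sites above pass valid_tool_names, which suggests filtering may have been intended; a hypothetical filtered variant (an assumption about intent, not what this commit ships) would look like:

from dataclasses import dataclass

@dataclass
class StubRule:  # stand-in for Letta's requires-approval tool rule
    tool_name: str

def get_requires_approval_tools(rules: list[StubRule], available_tools: set[str]) -> list[str]:
    # Hypothetical: intersect approval rules with the tools actually
    # available this step, as the signature implies.
    return [rule.tool_name for rule in rules if rule.tool_name in available_tools]

assert get_requires_approval_tools([StubRule("a"), StubRule("b")], {"a"}) == ["a"]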

View File

@@ -28,6 +28,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.schemas.letta_message import (
ApprovalRequestMessage,
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
@@ -59,7 +60,12 @@ class AnthropicStreamingInterface:
and detection of tool call events.
"""
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
def __init__(
self,
use_assistant_message: bool = False,
put_inner_thoughts_in_kwarg: bool = False,
requires_approval_tools: list = [],
):
self.json_parser: JSONParser = PydanticJSONParser()
self.use_assistant_message = use_assistant_message
@@ -90,6 +96,8 @@ class AnthropicStreamingInterface:
# Buffer to handle partial XML tags across chunks
self.partial_tag_buffer = ""
self.requires_approval_tools = requires_approval_tools
def get_tool_call_object(self) -> ToolCall:
"""Useful for agent loop"""
if not self.tool_call_name:
@@ -256,13 +264,15 @@ class AnthropicStreamingInterface:
self.inner_thoughts_complete = False
if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
# Only buffer the initial tool call message if it doesn't require approval
# For approval-required tools, we'll create the ApprovalRequestMessage later
if self.tool_call_name not in self.requires_approval_tools:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
@@ -353,11 +363,36 @@ class AnthropicStreamingInterface:
prev_message_type = reasoning_message.message_type
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
# Check if inner thoughts are complete - if so, flush the buffer or create approval message
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
if len(self.tool_call_buffer) > 0:
# Check if this tool requires approval
if self.tool_call_name in self.requires_approval_tools:
# Create ApprovalRequestMessage directly (buffer should be empty)
if prev_message_type and prev_message_type != "approval_request_message":
message_index += 1
# Strip out inner thoughts from arguments
tool_call_args = self.accumulated_tool_call_args
if current_inner_thoughts:
tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
approval_msg = ApprovalRequestMessage(
id=self.letta_message_id,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
date=datetime.now(timezone.utc).isoformat(),
name=self.tool_call_name,
tool_call=ToolCallDelta(
name=self.tool_call_name,
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
),
)
prev_message_type = approval_msg.message_type
yield approval_msg
elif len(self.tool_call_buffer) > 0:
# Flush buffered tool call messages for non-approval tools
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
@@ -371,9 +406,6 @@ class AnthropicStreamingInterface:
id=self.tool_call_buffer[0].id,
otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index),
date=self.tool_call_buffer[0].date,
name=self.tool_call_buffer[0].name,
sender_id=self.tool_call_buffer[0].sender_id,
step_id=self.tool_call_buffer[0].step_id,
tool_call=ToolCallDelta(
name=self.tool_call_name,
tool_call_id=self.tool_call_id,
@@ -404,11 +436,18 @@ class AnthropicStreamingInterface:
yield assistant_msg
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.tool_call_name in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
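In sum, the Anthropic interface now has two emission paths once inner thoughts complete: ordinary tools flush their buffered ToolCallMessage chunks, while approval-required tools (which were never buffered, per the change above) surface a single ApprovalRequestMessage carrying the accumulated arguments. A condensed sketch of that decision, with message types shown as strings:

# Condensed sketch of the flush decision above.
def flush_on_inner_thoughts_complete(
    tool_call_name: str,
    requires_approval_tools: list[str],
    tool_call_buffer: list[str],
) -> list[str]:
    if tool_call_name in requires_approval_tools:
        # Buffer is empty for approval tools; emit one approval request
        # carrying the accumulated (inner-thoughts-stripped) arguments.
        return ["approval_request_message"]
    # Non-approval tools flush every buffered tool-call chunk.
    return ["tool_call_message"] * len(tool_call_buffer)

assert flush_on_inner_thoughts_complete("get_secret_code_tool", ["get_secret_code_tool"], []) == ["approval_request_message"]
assert flush_on_inner_thoughts_complete("send_message", ["get_secret_code_tool"], ["a", "b"]) == ["tool_call_message", "tool_call_message"]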

View File

@@ -11,6 +11,7 @@ from letta.llm_api.openai_client import is_openai_reasoning_model
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.schemas.letta_message import (
ApprovalRequestMessage,
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
@@ -43,6 +44,7 @@ class OpenAIStreamingInterface:
messages: Optional[list] = None,
tools: Optional[list] = None,
put_inner_thoughts_in_kwarg: bool = True,
requires_approval_tools: list = [],
):
self.use_assistant_message = use_assistant_message
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
@@ -86,6 +88,8 @@ class OpenAIStreamingInterface:
self.reasoning_messages = []
self.emitted_hidden_reasoning = False # Track if we've emitted hidden reasoning message
self.requires_approval_tools = requires_approval_tools
def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
content = "".join(self.reasoning_messages).strip()
@@ -274,16 +278,28 @@ class OpenAIStreamingInterface:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
self.tool_call_name = str(self.function_name_buffer)
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
tool_call_id=self.function_id_buffer,
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
if self.tool_call_name in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
tool_call_id=self.function_id_buffer,
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
tool_call_id=self.function_id_buffer,
),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -404,17 +420,30 @@ class OpenAIStreamingInterface:
combined_chunk = self.function_args_buffer + updates_main_json
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
if self.function_name_buffer in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
# clear buffer
@@ -424,17 +453,30 @@ class OpenAIStreamingInterface:
# If there's no buffer to clear, just output a new chunk with new data
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
if self.function_name_buffer in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
self.function_id_buffer = None
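All three OpenAI hunks repeat the same branch: an identical ToolCallDelta payload, with only the wrapper class depending on whether the buffered function name requires approval. A hypothetical helper factoring out that choice (not part of this commit) might read:

from dataclasses import dataclass

@dataclass
class StubDelta:  # stand-in for ToolCallDelta
    name: str | None
    arguments: str | None
    tool_call_id: str | None

def pick_message_type(function_name: str | None, requires_approval_tools: list[str]) -> str:
    # Payload is identical either way; only the message class differs.
    if function_name in requires_approval_tools:
        return "approval_request_message"
    return "tool_call_message"

delta = StubDelta(name="get_secret_code_tool", arguments=None, tool_call_id="call_1")
assert pick_message_type(delta.name, ["get_secret_code_tool"]) == "approval_request_message"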

View File

@@ -265,7 +265,7 @@ class ApprovalRequestMessage(LettaMessage):
message_type: Literal[MessageType.approval_request_message] = Field(
default=MessageType.approval_request_message, description="The type of the message."
)
tool_call: ToolCall = Field(..., description="The tool call that has been requested by the llm to run")
tool_call: Union[ToolCall, ToolCallDelta] = Field(..., description="The tool call that has been requested by the llm to run")
class ApprovalResponseMessage(LettaMessage):
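Widening tool_call to Union[ToolCall, ToolCallDelta] is what lets the streaming interfaces emit approval requests chunk by chunk: mid-stream, only a partial delta exists. A minimal Pydantic sketch of the same shape, with stand-in models:

from typing import Optional, Union
from pydantic import BaseModel

class StubToolCall(BaseModel):  # stand-in for the complete tool call
    name: str
    arguments: str
    tool_call_id: str

class StubToolCallDelta(BaseModel):  # stand-in for a streamed partial call
    name: Optional[str] = None
    arguments: Optional[str] = None
    tool_call_id: Optional[str] = None

class StubApprovalRequest(BaseModel):
    # Accepts a complete call (non-streaming) or a partial delta (streaming).
    tool_call: Union[StubToolCall, StubToolCallDelta]

# Valid mid-stream, before the full arguments have arrived:
StubApprovalRequest(tool_call=StubToolCallDelta(name="get_secret_code_tool"))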

View File

@@ -45,7 +45,7 @@ dependencies = [
"llama-index>=0.12.2",
"llama-index-embeddings-openai>=0.3.1",
"anthropic>=0.49.0",
"letta-client==0.1.307",
"letta-client==0.1.314",
"openai>=1.99.9",
"opentelemetry-api==1.30.0",
"opentelemetry-sdk==1.30.0",

View File

@@ -51,6 +51,17 @@ def get_secret_code_tool(input_text: str) -> str:
return str(abs(hash(input_text)))
def accumulate_chunks(stream):
messages = []
prev_message_type = None
for chunk in stream:
current_message_type = chunk.message_type
if prev_message_type != current_message_type:
messages.append(chunk)
prev_message_type = current_message_type
return messages
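For context, accumulate_chunks keeps only the first chunk of each run of consecutive same-type chunks, so the assertions below count message types rather than raw token deltas. A quick illustration against a hypothetical chunk stream:

from types import SimpleNamespace

# Hypothetical stream: two reasoning deltas collapse into one entry.
chunks = [SimpleNamespace(message_type=t) for t in (
    "reasoning_message", "reasoning_message",
    "approval_request_message", "stop_reason", "usage_statistics",
)]
assert [m.message_type for m in accumulate_chunks(chunks)] == [
    "reasoning_message", "approval_request_message", "stop_reason", "usage_statistics",
]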
# ------------------------------
# Fixtures
# ------------------------------
@@ -185,15 +196,21 @@ def test_send_message_with_requires_approval_tool(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
stream_tokens=True,
)
assert response.messages is not None
assert len(response.messages) == 2
assert response.messages[0].message_type == "reasoning_message"
assert response.messages[1].message_type == "approval_request_message"
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 4
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "approval_request_message"
assert messages[2].message_type == "stop_reason"
assert messages[2].stop_reason == "requires_approval"
assert messages[3].message_type == "usage_statistics"
def test_send_message_after_turning_off_requires_approval(
@@ -201,13 +218,11 @@ def test_send_message_after_turning_off_requires_approval(
agent: AgentState,
approval_tool_fixture: Tool,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
approval_request_id = response.messages[0].id
response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
messages = accumulate_chunks(response)
approval_request_id = messages[0].id
client.agents.messages.create(
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
ApprovalCreate(
@@ -215,7 +230,9 @@ def test_send_message_after_turning_off_requires_approval(
approval_request_id=approval_request_id,
),
],
stream_tokens=True,
)
messages = accumulate_chunks(response)
client.agents.tools.modify_approval(
agent_id=agent.id,
@@ -223,19 +240,18 @@ def test_send_message_after_turning_off_requires_approval(
requires_approval=False,
)
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
assert response.messages is not None
assert len(response.messages) == 3 or len(response.messages) == 5
assert response.messages[0].message_type == "reasoning_message"
assert response.messages[1].message_type == "tool_call_message"
assert response.messages[2].message_type == "tool_return_message"
if len(response.messages) == 5:
assert response.messages[3].message_type == "reasoning_message"
assert response.messages[4].message_type == "assistant_message"
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) == 5 or len(messages) == 7
assert messages[0].message_type == "reasoning_message"
assert messages[1].message_type == "tool_call_message"
assert messages[2].message_type == "tool_return_message"
if len(messages) > 5:
assert messages[3].message_type == "reasoning_message"
assert messages[4].message_type == "assistant_message"
# ------------------------------

uv.lock (generated)
View File

@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.11, <3.14"
resolution-markers = [
"python_full_version >= '3.13'",
@@ -2608,7 +2608,7 @@ requires-dist = [
{ name = "langchain", marker = "extra == 'external-tools'", specifier = ">=0.3.7" },
{ name = "langchain-community", marker = "extra == 'desktop'", specifier = ">=0.3.7" },
{ name = "langchain-community", marker = "extra == 'external-tools'", specifier = ">=0.3.7" },
{ name = "letta-client", specifier = "==0.1.307" },
{ name = "letta-client", specifier = "==0.1.314" },
{ name = "llama-index", specifier = ">=0.12.2" },
{ name = "llama-index-embeddings-openai", specifier = ">=0.3.1" },
{ name = "locust", marker = "extra == 'desktop'", specifier = ">=2.31.5" },
@@ -2684,7 +2684,7 @@ provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "s
[[package]]
name = "letta-client"
version = "0.1.307"
version = "0.1.314"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
@@ -2693,9 +2693,9 @@ dependencies = [
{ name = "pydantic-core" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4d/ea/e6148fefa2f2925c49cd0569c9235c73f7699871d4b1be456f899774cdfd/letta_client-0.1.307.tar.gz", hash = "sha256:215b6d23cfc28a79812490ddb991bd979057ca28cd8491576873473b140086a7", size = 190679, upload-time = "2025-09-03T18:30:09.634Z" }
sdist = { url = "https://files.pythonhosted.org/packages/bb/b1/5f84118594b94bc0bb413ad7162458a2b83e469c21989681cd64cd5f279b/letta_client-0.1.314.tar.gz", hash = "sha256:bb8e4ed389faceaceadef1122444bb263517e8af3dcf21c63b51cf7828a897f2", size = 196791, upload-time = "2025-09-05T22:58:16.77Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/df/3f/cdc1d401037970d83c45a212d66a0319608c0e2f1da1536074e5de5353ed/letta_client-0.1.307-py3-none-any.whl", hash = "sha256:f07c3d58f2767e9ad9ecb11ca9227ba368e466ad05b48a7a49b5a1edd15b4cbc", size = 478428, upload-time = "2025-09-03T18:30:08.092Z" },
{ url = "https://files.pythonhosted.org/packages/5a/bb/962cba922b17d51ff158ee096301233c1583e66551cfada31d83bf0bf33e/letta_client-0.1.314-py3-none-any.whl", hash = "sha256:7a82f963188857f82952a0b44b1045d6d172908cd61d01cffa39fe21a2f20077", size = 492795, upload-time = "2025-09-05T22:58:14.814Z" },
]
[[package]]