diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index 19adb37f..44e58ebc 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -1022,6 +1022,7 @@ class LettaAgent(BaseAgent):
                     interface = AnthropicStreamingInterface(
                         use_assistant_message=use_assistant_message,
                         put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
+                        requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
                     )
                 elif agent_state.llm_config.model_endpoint_type == ProviderType.openai:
                     interface = OpenAIStreamingInterface(
@@ -1030,6 +1031,7 @@ class LettaAgent(BaseAgent):
                         messages=current_in_context_messages + new_in_context_messages,
                         tools=request_data.get("tools", []),
                         put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
+                        requires_approval_tools=tool_rules_solver.get_requires_approval_tools(valid_tool_names),
                     )
                 else:
                     raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
@@ -1174,12 +1176,13 @@ class LettaAgent(BaseAgent):
                     )
                     step_progression = StepProgression.LOGGED_TRACE
 
-                # yields tool response as this is handled from Letta and not the response from the LLM provider
-                tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
-                if not (use_assistant_message and tool_return.name == "send_message"):
-                    # Apply message type filtering if specified
-                    if include_return_message_types is None or tool_return.message_type in include_return_message_types:
-                        yield f"data: {tool_return.model_dump_json()}\n\n"
+                if persisted_messages[-1].role != "approval":
+                    # yields tool response as this is handled from Letta and not the response from the LLM provider
+                    tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
+                    if not (use_assistant_message and tool_return.name == "send_message"):
+                        # Apply message type filtering if specified
+                        if include_return_message_types is None or tool_return.message_type in include_return_message_types:
+                            yield f"data: {tool_return.model_dump_json()}\n\n"
 
                 # TODO (cliandy): consolidate and expand with trace
                 MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
@@ -1692,7 +1695,6 @@ class LettaAgent(BaseAgent):
             tool_call_id=tool_call_id,
             request_heartbeat=request_heartbeat,
         )
-
         if not is_approval and tool_rules_solver.is_requires_approval_tool(tool_call_name):
             approval_message = create_approval_request_message_from_llm_response(
                 agent_id=agent_state.id,
diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py
index e0f1f4d5..73384971 100644
--- a/letta/helpers/tool_rule_solver.py
+++ b/letta/helpers/tool_rule_solver.py
@@ -131,6 +131,10 @@ class ToolRulesSolver(BaseModel):
         """Check if all required-before-exit tools have been called."""
         return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
 
+    def get_requires_approval_tools(self, available_tools: set[ToolName]) -> list[ToolName]:
+        """Get the names of all tools with a requires-approval rule; `available_tools` is currently not used for filtering."""
+        return [rule.tool_name for rule in self.requires_approval_tool_rules]
+
     def get_uncalled_required_tools(self, available_tools: set[ToolName]) -> list[str]:
         """Get the list of required-before-exit tools that have not been called yet."""
         if not self.required_before_exit_tool_rules:
diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py
index 8f84c23a..e295fdd7 100644
--- a/letta/interfaces/anthropic_streaming_interface.py
+++ b/letta/interfaces/anthropic_streaming_interface.py
@@ -28,6 +28,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
 from letta.log import get_logger
 from letta.schemas.letta_message import (
+    ApprovalRequestMessage,
     AssistantMessage,
     HiddenReasoningMessage,
     LettaMessage,
@@ -59,7 +60,12 @@ class AnthropicStreamingInterface:
     and detection of tool call events.
     """
 
-    def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
+    def __init__(
+        self,
+        use_assistant_message: bool = False,
+        put_inner_thoughts_in_kwarg: bool = False,
+        requires_approval_tools: list | None = None,
+    ):
         self.json_parser: JSONParser = PydanticJSONParser()
         self.use_assistant_message = use_assistant_message
 
@@ -90,6 +96,9 @@ class AnthropicStreamingInterface:
         # Buffer to handle partial XML tags across chunks
         self.partial_tag_buffer = ""
 
+        # Names of tools whose calls are streamed as approval requests; copied so no list is shared between instances
+        self.requires_approval_tools = list(requires_approval_tools or [])
+
     def get_tool_call_object(self) -> ToolCall:
         """Useful for agent loop"""
         if not self.tool_call_name:
@@ -256,13 +265,15 @@ class AnthropicStreamingInterface:
                 self.inner_thoughts_complete = False
 
                 if not self.use_assistant_message:
-                    # Buffer the initial tool call message instead of yielding immediately
-                    tool_call_msg = ToolCallMessage(
-                        id=self.letta_message_id,
-                        tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
-                        date=datetime.now(timezone.utc).isoformat(),
-                    )
-                    self.tool_call_buffer.append(tool_call_msg)
+                    # Only buffer the initial tool call message if it doesn't require approval
+                    # For approval-required tools, we'll create the ApprovalRequestMessage later
+                    if self.tool_call_name not in self.requires_approval_tools:
+                        tool_call_msg = ToolCallMessage(
+                            id=self.letta_message_id,
+                            tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
+                            date=datetime.now(timezone.utc).isoformat(),
+                        )
+                        self.tool_call_buffer.append(tool_call_msg)
             elif isinstance(content, BetaThinkingBlock):
                 self.anthropic_mode = EventMode.THINKING
                 # TODO: Can capture signature, etc.
@@ -353,11 +364,36 @@ class AnthropicStreamingInterface:
                     prev_message_type = reasoning_message.message_type
                     yield reasoning_message
 
-                # Check if inner thoughts are complete - if so, flush the buffer
+                # Check if inner thoughts are complete - if so, flush the buffer or create approval message
                 if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
                     self.inner_thoughts_complete = True
-                    # Flush all buffered tool call messages
-                    if len(self.tool_call_buffer) > 0:
+
+                    # Check if this tool requires approval
+                    if self.tool_call_name in self.requires_approval_tools:
+                        # Create ApprovalRequestMessage directly (buffer should be empty)
+                        if prev_message_type and prev_message_type != "approval_request_message":
+                            message_index += 1
+
+                        # Strip out inner thoughts from arguments
+                        tool_call_args = self.accumulated_tool_call_args
+                        if current_inner_thoughts:
+                            tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "")
+
+                        approval_msg = ApprovalRequestMessage(
+                            id=self.letta_message_id,
+                            otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                            date=datetime.now(timezone.utc).isoformat(),
+                            name=self.tool_call_name,
+                            tool_call=ToolCallDelta(
+                                name=self.tool_call_name,
+                                tool_call_id=self.tool_call_id,
+                                arguments=tool_call_args,
+                            ),
+                        )
+                        prev_message_type = approval_msg.message_type
+                        yield approval_msg
+                    elif len(self.tool_call_buffer) > 0:
+                        # Flush buffered tool call messages for non-approval tools
                         if prev_message_type and prev_message_type != "tool_call_message":
                             message_index += 1
 
@@ -371,9 +407,6 @@ class AnthropicStreamingInterface:
                             id=self.tool_call_buffer[0].id,
                             otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index),
                             date=self.tool_call_buffer[0].date,
-                            name=self.tool_call_buffer[0].name,
-                            sender_id=self.tool_call_buffer[0].sender_id,
-                            step_id=self.tool_call_buffer[0].step_id,
                             tool_call=ToolCallDelta(
                                 name=self.tool_call_name,
                                 tool_call_id=self.tool_call_id,
@@ -404,11 +437,19 @@ class AnthropicStreamingInterface:
                         yield assistant_msg
                 else:
                     # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
-                    tool_call_msg = ToolCallMessage(
-                        id=self.letta_message_id,
-                        tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
-                        date=datetime.now(timezone.utc).isoformat(),
-                    )
+                    if self.tool_call_name in self.requires_approval_tools:
+                        tool_call_msg = ApprovalRequestMessage(
+                            id=self.letta_message_id,
+                            tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
+                            date=datetime.now(timezone.utc).isoformat(),
+                        )
+                    else:
+                        tool_call_msg = ToolCallMessage(
+                            id=self.letta_message_id,
+                            tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
+                            date=datetime.now(timezone.utc).isoformat(),
+                        )
                     if self.inner_thoughts_complete:
-                        if prev_message_type and prev_message_type != "tool_call_message":
+                        # Compare with this chunk's own type so consecutive approval chunks keep one message index
+                        if prev_message_type and prev_message_type != tool_call_msg.message_type:
                             message_index += 1
diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py
index 10c6ed78..0ff2c6fb 100644
--- a/letta/interfaces/openai_streaming_interface.py
+++ b/letta/interfaces/openai_streaming_interface.py
@@ -11,6 +11,7 @@ from letta.llm_api.openai_client import is_openai_reasoning_model
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.log import get_logger
 from letta.schemas.letta_message import (
+    ApprovalRequestMessage,
     AssistantMessage,
     HiddenReasoningMessage,
     LettaMessage,
@@ -43,6 +44,7 @@ class OpenAIStreamingInterface:
         messages: Optional[list] = None,
         tools: Optional[list] = None,
         put_inner_thoughts_in_kwarg: bool = True,
+        requires_approval_tools: Optional[list] = None,
     ):
         self.use_assistant_message = use_assistant_message
         self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
@@ -86,6 +88,9 @@ class OpenAIStreamingInterface:
         self.reasoning_messages = []
         self.emitted_hidden_reasoning = False  # Track if we've emitted hidden reasoning message
 
+        # Names of tools whose calls are streamed as approval requests; copied so no list is shared between instances
+        self.requires_approval_tools = list(requires_approval_tools or [])
+
     def get_reasoning_content(self) -> list[TextContent | OmittedReasoningContent]:
         content = "".join(self.reasoning_messages).strip()
 
@@ -274,16 +279,28 @@ class OpenAIStreamingInterface:
                                 if prev_message_type and prev_message_type != "tool_call_message":
                                     message_index += 1
                                 self.tool_call_name = str(self.function_name_buffer)
-                                tool_call_msg = ToolCallMessage(
-                                    id=self.letta_message_id,
-                                    date=datetime.now(timezone.utc),
-                                    tool_call=ToolCallDelta(
-                                        name=self.function_name_buffer,
-                                        arguments=None,
-                                        tool_call_id=self.function_id_buffer,
-                                    ),
-                                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
-                                )
+                                if self.tool_call_name in self.requires_approval_tools:
+                                    tool_call_msg = ApprovalRequestMessage(
+                                        id=self.letta_message_id,
+                                        date=datetime.now(timezone.utc),
+                                        tool_call=ToolCallDelta(
+                                            name=self.function_name_buffer,
+                                            arguments=None,
+                                            tool_call_id=self.function_id_buffer,
+                                        ),
+                                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                                    )
+                                else:
+                                    tool_call_msg = ToolCallMessage(
+                                        id=self.letta_message_id,
+                                        date=datetime.now(timezone.utc),
+                                        tool_call=ToolCallDelta(
+                                            name=self.function_name_buffer,
+                                            arguments=None,
+                                            tool_call_id=self.function_id_buffer,
+                                        ),
+                                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                                    )
                                 prev_message_type = tool_call_msg.message_type
                                 yield tool_call_msg
 
@@ -404,17 +421,33 @@ class OpenAIStreamingInterface:
                                 combined_chunk = self.function_args_buffer + updates_main_json
-                                if prev_message_type and prev_message_type != "tool_call_message":
-                                    message_index += 1
-                                tool_call_msg = ToolCallMessage(
-                                    id=self.letta_message_id,
-                                    date=datetime.now(timezone.utc),
-                                    tool_call=ToolCallDelta(
-                                        name=self.function_name_buffer,
-                                        arguments=combined_chunk,
-                                        tool_call_id=self.function_id_buffer,
-                                    ),
-                                    # name=name,
-                                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
-                                )
+                                # The name buffer may already be flushed here, so check the recorded tool_call_name
+                                if self.tool_call_name in self.requires_approval_tools:
+                                    if prev_message_type and prev_message_type != "approval_request_message":
+                                        message_index += 1
+                                    tool_call_msg = ApprovalRequestMessage(
+                                        id=self.letta_message_id,
+                                        date=datetime.now(timezone.utc),
+                                        tool_call=ToolCallDelta(
+                                            name=self.function_name_buffer,
+                                            arguments=combined_chunk,
+                                            tool_call_id=self.function_id_buffer,
+                                        ),
+                                        # name=name,
+                                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                                    )
+                                else:
+                                    if prev_message_type and prev_message_type != "tool_call_message":
+                                        message_index += 1
+                                    tool_call_msg = ToolCallMessage(
+                                        id=self.letta_message_id,
+                                        date=datetime.now(timezone.utc),
+                                        tool_call=ToolCallDelta(
+                                            name=self.function_name_buffer,
+                                            arguments=combined_chunk,
+                                            tool_call_id=self.function_id_buffer,
+                                        ),
+                                        # name=name,
+                                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                                    )
                                 prev_message_type = tool_call_msg.message_type
                                 yield tool_call_msg
                                 # clear buffer
@@ -424,17 +457,33 @@ class OpenAIStreamingInterface:
                                 # If there's no buffer to clear, just output a new chunk with new data
-                                if prev_message_type and prev_message_type != "tool_call_message":
-                                    message_index += 1
-                                tool_call_msg = ToolCallMessage(
-                                    id=self.letta_message_id,
-                                    date=datetime.now(timezone.utc),
-                                    tool_call=ToolCallDelta(
-                                        name=None,
-                                        arguments=updates_main_json,
-                                        tool_call_id=self.function_id_buffer,
-                                    ),
-                                    # name=name,
-                                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
-                                )
+                                # The name buffer may already be flushed here, so check the recorded tool_call_name
+                                if self.tool_call_name in self.requires_approval_tools:
+                                    if prev_message_type and prev_message_type != "approval_request_message":
+                                        message_index += 1
+                                    tool_call_msg = ApprovalRequestMessage(
+                                        id=self.letta_message_id,
+                                        date=datetime.now(timezone.utc),
+                                        tool_call=ToolCallDelta(
+                                            name=None,
+                                            arguments=updates_main_json,
+                                            tool_call_id=self.function_id_buffer,
+                                        ),
+                                        # name=name,
+                                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                                    )
+                                else:
+                                    if prev_message_type and prev_message_type != "tool_call_message":
+                                        message_index += 1
+                                    tool_call_msg = ToolCallMessage(
+                                        id=self.letta_message_id,
+                                        date=datetime.now(timezone.utc),
+                                        tool_call=ToolCallDelta(
+                                            name=None,
+                                            arguments=updates_main_json,
+                                            tool_call_id=self.function_id_buffer,
+                                        ),
+                                        # name=name,
+                                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
+                                    )
                                 prev_message_type = tool_call_msg.message_type
                                 yield tool_call_msg
                                 self.function_id_buffer = None
diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py
index 0e3859fd..1d1904e0 100644
--- a/letta/schemas/letta_message.py
+++ b/letta/schemas/letta_message.py
@@ -265,7 +265,7 @@ class ApprovalRequestMessage(LettaMessage):
     message_type: Literal[MessageType.approval_request_message] = Field(
         default=MessageType.approval_request_message, description="The type of the message."
     )
-    tool_call: ToolCall = Field(..., description="The tool call that has been requested by the llm to run")
+    tool_call: Union[ToolCall, ToolCallDelta] = Field(..., description="The tool call that has been requested by the llm to run")
 
 
 class ApprovalResponseMessage(LettaMessage):
diff --git a/pyproject.toml b/pyproject.toml
index f030fe4b..280eada4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,7 +45,7 @@ dependencies = [
     "llama-index>=0.12.2",
     "llama-index-embeddings-openai>=0.3.1",
     "anthropic>=0.49.0",
-    "letta-client==0.1.307",
+    "letta-client==0.1.314",
     "openai>=1.99.9",
     "opentelemetry-api==1.30.0",
     "opentelemetry-sdk==1.30.0",
diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py
index 61061fae..ef16be99 100644
--- a/tests/integration_test_human_in_the_loop.py
+++ b/tests/integration_test_human_in_the_loop.py
@@ -51,6 +51,18 @@ def get_secret_code_tool(input_text: str) -> str:
     return str(abs(hash(input_text)))
 
 
+def accumulate_chunks(stream):
+    """Keep only the first chunk of each run of consecutive chunks that share a message_type."""
+    messages = []
+    prev_message_type = None
+    for chunk in stream:
+        current_message_type = chunk.message_type
+        if prev_message_type != current_message_type:
+            messages.append(chunk)
+        prev_message_type = current_message_type
+    return messages
+
+
 # ------------------------------
 # Fixtures
 # ------------------------------
@@ -185,15 +197,21 @@ def test_send_message_with_requires_approval_tool(
     client: Letta,
     agent: AgentState,
 ) -> None:
-    response = client.agents.messages.create(
+    response = client.agents.messages.create_stream(
         agent_id=agent.id,
         messages=USER_MESSAGE_TEST_APPROVAL,
+        stream_tokens=True,
     )
 
-    assert response.messages is not None
-    assert len(response.messages) == 2
-    assert response.messages[0].message_type == "reasoning_message"
-    assert response.messages[1].message_type == "approval_request_message"
+    messages = accumulate_chunks(response)
+
+    assert messages is not None
+    assert len(messages) == 4
+    assert messages[0].message_type == "reasoning_message"
+    assert messages[1].message_type == "approval_request_message"
+    assert messages[2].message_type == "stop_reason"
+    assert messages[2].stop_reason == "requires_approval"
+    assert messages[3].message_type == "usage_statistics"
 
 
 def test_send_message_after_turning_off_requires_approval(
@@ -201,13 +219,11 @@ def test_send_message_after_turning_off_requires_approval(
     agent: AgentState,
     approval_tool_fixture: Tool,
 ) -> None:
-    response = client.agents.messages.create(
-        agent_id=agent.id,
-        messages=USER_MESSAGE_TEST_APPROVAL,
-    )
-    approval_request_id = response.messages[0].id
+    response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
+    messages = accumulate_chunks(response)
+    approval_request_id = messages[0].id
 
-    client.agents.messages.create(
+    response = client.agents.messages.create_stream(
         agent_id=agent.id,
         messages=[
             ApprovalCreate(
@@ -215,7 +231,9 @@ def test_send_message_after_turning_off_requires_approval(
                 approval_request_id=approval_request_id,
             ),
         ],
+        stream_tokens=True,
     )
+    messages = accumulate_chunks(response)
 
     client.agents.tools.modify_approval(
         agent_id=agent.id,
@@ -223,19 +241,18 @@ def test_send_message_after_turning_off_requires_approval(
         requires_approval=False,
     )
 
-    response = client.agents.messages.create(
-        agent_id=agent.id,
-        messages=USER_MESSAGE_TEST_APPROVAL,
-    )
+    response = client.agents.messages.create_stream(agent_id=agent.id, messages=USER_MESSAGE_TEST_APPROVAL, stream_tokens=True)
 
-    assert response.messages is not None
-    assert len(response.messages) == 3 or len(response.messages) == 5
-    assert response.messages[0].message_type == "reasoning_message"
-    assert response.messages[1].message_type == "tool_call_message"
-    assert response.messages[2].message_type == "tool_return_message"
-    if len(response.messages) == 5:
-        assert response.messages[3].message_type == "reasoning_message"
-        assert response.messages[4].message_type == "assistant_message"
+    messages = accumulate_chunks(response)
+
+    assert messages is not None
+    assert len(messages) == 5 or len(messages) == 7
+    assert messages[0].message_type == "reasoning_message"
+    assert messages[1].message_type == "tool_call_message"
+    assert messages[2].message_type == "tool_return_message"
+    if len(messages) > 5:
+        assert messages[3].message_type == "reasoning_message"
+        assert messages[4].message_type == "assistant_message"
 
 
 # ------------------------------
diff --git a/uv.lock b/uv.lock
index 85b5a04d..763bf2bc 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.11, <3.14"
 resolution-markers = [
     "python_full_version >= '3.13'",
@@ -2608,7 +2608,7 @@ requires-dist = [
     { name = "langchain", marker = "extra == 'external-tools'", specifier = ">=0.3.7" },
     { name = "langchain-community", marker = "extra == 'desktop'", specifier = ">=0.3.7" },
     { name = "langchain-community", marker = "extra == 'external-tools'", specifier = ">=0.3.7" },
-    { name = "letta-client", specifier = "==0.1.307" },
+    { name = "letta-client", specifier = "==0.1.314" },
     { name = "llama-index", specifier = ">=0.12.2" },
     { name = "llama-index-embeddings-openai", specifier = ">=0.3.1" },
     { name = "locust", marker = "extra == 'desktop'", specifier = ">=2.31.5" },
@@ -2684,7 +2684,7 @@ provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "s
 
 [[package]]
 name = "letta-client"
-version = "0.1.307"
+version = "0.1.314"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "httpx" },
@@ -2693,9 +2693,9 @@ dependencies = [
     { name = "pydantic-core" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/4d/ea/e6148fefa2f2925c49cd0569c9235c73f7699871d4b1be456f899774cdfd/letta_client-0.1.307.tar.gz", hash = "sha256:215b6d23cfc28a79812490ddb991bd979057ca28cd8491576873473b140086a7", size = 190679, upload-time = "2025-09-03T18:30:09.634Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/bb/b1/5f84118594b94bc0bb413ad7162458a2b83e469c21989681cd64cd5f279b/letta_client-0.1.314.tar.gz", hash = "sha256:bb8e4ed389faceaceadef1122444bb263517e8af3dcf21c63b51cf7828a897f2", size = 196791, upload-time = "2025-09-05T22:58:16.77Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/df/3f/cdc1d401037970d83c45a212d66a0319608c0e2f1da1536074e5de5353ed/letta_client-0.1.307-py3-none-any.whl", hash = "sha256:f07c3d58f2767e9ad9ecb11ca9227ba368e466ad05b48a7a49b5a1edd15b4cbc", size = 478428, upload-time = "2025-09-03T18:30:08.092Z" },
+    { url = "https://files.pythonhosted.org/packages/5a/bb/962cba922b17d51ff158ee096301233c1583e66551cfada31d83bf0bf33e/letta_client-0.1.314-py3-none-any.whl", hash = "sha256:7a82f963188857f82952a0b44b1045d6d172908cd61d01cffa39fe21a2f20077", size = 492795, upload-time = "2025-09-05T22:58:14.814Z" },
 ]
 
 [[package]]