diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index fa093129..d73f26f4 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -27,7 +27,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import LettaMessage, MessageType +from letta.schemas.letta_message import ApprovalReturn, LettaMessage, MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType @@ -318,19 +318,19 @@ class LettaAgentV3(LettaAgentV2): pending_tool_calls = { backfill_tool_call_id if a.tool_call_id.startswith("message-") else a.tool_call_id: a for a in approval_response.approvals - if a.type == "approval" and a.approve + if isinstance(a, ApprovalReturn) and a.approve } tool_calls = [t for t in approval_request.tool_calls if t.id in pending_tool_calls] # Get tool calls that were denied - denies = {d.tool_call_id: d for d in approval_response.approvals if d.type == "approval" and not d.approve} + denies = {d.tool_call_id: d for d in approval_response.approvals if isinstance(d, ApprovalReturn) and not d.approve} tool_call_denials = [ ToolCallDenial(**t.model_dump(), reason=denies.get(t.id).reason) for t in approval_request.tool_calls if t.id in denies ] # Get tool calls that were executed client side if approval_response.approvals: - tool_returns = [r for r in approval_response.approvals if r.type == "tool"] + tool_returns = [r for r in approval_response.approvals if isinstance(r, ToolReturn)] step_id = approval_request.step_id step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor) diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 7bf58059..5d905a62 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -220,7 +220,7 @@ def test_send_approval_message_with_incorrect_request_id(client, agent): # ------------------------------ -def test_send_message_with_requires_approval_tool( +def test_invoke_approval_request( client: Letta, agent: AgentState, ) -> None: @@ -246,7 +246,45 @@ def test_send_message_with_requires_approval_tool( assert messages[4].message_type == "usage_statistics" -def test_send_message_after_turning_off_requires_approval( +def test_invoke_approval_request_with_context_check( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + try: + client.agents.context.retrieve(agent_id=agent.id) + except Exception as e: + if len(messages) > 4: + raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") + raise e + + +def test_invoke_tool_after_turning_off_requires_approval( client: Letta, agent: AgentState, approval_tool_fixture: Tool, @@ -434,6 +472,44 @@ def test_approve_cursor_fetch( assert messages[3].message_type == "assistant_message" +def test_approve_with_context_check( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=False, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + approvals=[ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + try: + client.agents.context.retrieve(agent_id=agent.id) + except Exception as e: + if len(messages) > 4: + raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") + raise e + + def test_approve_and_follow_up( client: Letta, agent: AgentState, @@ -643,6 +719,46 @@ def test_deny_cursor_fetch( assert messages[3].message_type == "assistant_message" +def test_deny_with_context_check( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy + approvals=[ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.", + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + try: + client.agents.context.retrieve(agent_id=agent.id) + except Exception as e: + if len(messages) > 4: + raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") + raise e + + def test_deny_and_follow_up( client: Letta, agent: AgentState, @@ -688,46 +804,6 @@ def test_deny_and_follow_up( assert messages[3].message_type == "usage_statistics" -def test_deny_with_context_check( - client: Letta, - agent: AgentState, -) -> None: - response = client.agents.messages.create( - agent_id=agent.id, - messages=USER_MESSAGE_TEST_APPROVAL, - ) - tool_call_id = response.messages[2].tool_call.tool_call_id - - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - ApprovalCreate( - approve=True, # legacy (passing incorrect value to ensure it is overridden) - approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) - reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy - approvals=[ - { - "type": "approval", - "approve": False, - "tool_call_id": tool_call_id, - "reason": "Cancelled by user. Instead of responding, wait for next user input before replying.", - }, - ], - ), - ], - stream_tokens=True, - ) - - messages = accumulate_chunks(response) - - try: - client.agents.context.retrieve(agent_id=agent.id) - except Exception as e: - if len(messages) > 4: - raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") - raise e - - def test_deny_and_follow_up_with_error( client: Letta, agent: AgentState, @@ -781,3 +857,251 @@ def test_deny_and_follow_up_with_error( assert messages[1].message_type == "assistant_message" assert messages[2].message_type == "stop_reason" assert messages[3].message_type == "usage_statistics" + + +# -------------------------------- +# Client-Side Execution Test Cases +# -------------------------------- + + +def test_client_side_tool_call_request( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + ), + ], + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) == 5 + assert messages[0].message_type == "tool_return_message" + assert messages[0].tool_call_id == tool_call_id + assert messages[0].status == "success" + assert messages[0].tool_return == SECRET_CODE + assert messages[1].message_type == "reasoning_message" + assert messages[2].message_type == "assistant_message" + assert SECRET_CODE in messages[2].content + assert messages[3].message_type == "stop_reason" + assert messages[4].message_type == "usage_statistics" + + +def test_client_side_tool_call_cursor_fetch( + client: Letta, + agent: AgentState, +) -> None: + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + last_message_id = response.messages[0].id + tool_call_id = response.messages[2].tool_call.tool_call_id + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + assert len(messages) == 4 + assert messages[0].message_type == "user_message" + assert messages[1].message_type == "reasoning_message" + assert messages[2].message_type == "assistant_message" + assert messages[3].message_type == "approval_request_message" + assert messages[3].tool_call.tool_call_id == tool_call_id + # Ensure no request_heartbeat on approval request + # import json as _json + + # _args = _json.loads(messages[2].tool_call.arguments) + # assert "request_heartbeat" not in _args + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + ), + ], + ) + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) + assert len(messages) == 4 + assert messages[0].message_type == "approval_response_message" + assert messages[0].approvals[0]["type"] == "tool" + assert messages[0].approvals[0]["tool_call_id"] == tool_call_id + assert messages[0].approvals[0]["tool_return"] == SECRET_CODE + assert messages[0].approvals[0]["status"] == "success" + assert messages[1].message_type == "tool_return_message" + assert messages[1].status == "success" + assert messages[1].tool_call_id == tool_call_id + assert messages[1].tool_return == SECRET_CODE + assert messages[2].message_type == "reasoning_message" + assert messages[3].message_type == "assistant_message" + + +def test_client_side_tool_call_with_context_check( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason="Cancelled by user. Instead of responding, wait for next user input before replying.", # legacy + approvals=[ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + try: + client.agents.context.retrieve(agent_id=agent.id) + except Exception as e: + if len(messages) > 4: + raise ValueError("Model did not respond with only reasoning content, please rerun test to repro edge case.") + raise e + + +def test_client_side_tool_call_and_follow_up( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + ), + ], + ) + + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) == 4 + assert messages[0].message_type == "reasoning_message" + assert messages[1].message_type == "assistant_message" + assert messages[2].message_type == "stop_reason" + assert messages[3].message_type == "usage_statistics" + + +def test_client_side_tool_call_and_follow_up_with_error( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[2].tool_call.tool_call_id + + # Mock the streaming adapter to return llm invocation failure on the follow up turn + with patch.object(SimpleLLMStreamAdapter, "invoke_llm", side_effect=ValueError("TEST: Mocked error")): + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, # legacy (passing incorrect value to ensure it is overridden) + approval_request_id=FAKE_REQUEST_ID, # legacy (passing incorrect value to ensure it is overridden) + reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}", # legacy + approvals=[ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + ), + ], + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + stop_reason_message = [m for m in messages if m.message_type == "stop_reason"][0] + assert stop_reason_message + assert stop_reason_message.stop_reason == "invalid_llm_response" + + # Ensure that agent is not bricked + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) == 4 + assert messages[0].message_type == "reasoning_message" + assert messages[1].message_type == "assistant_message" + assert messages[2].message_type == "stop_reason" + assert messages[3].message_type == "usage_statistics"