diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 3e5aed0a..5cddbb89 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -144,7 +144,7 @@ async def _prepare_in_context_messages_no_persist_async( current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) # Check for approval-related message validation - if len(input_messages) == 1 and input_messages[0].type == "approval": + if input_messages[0].type == "approval": # User is trying to send an approval response if current_in_context_messages and current_in_context_messages[-1].role != "approval": raise ValueError( @@ -155,6 +155,11 @@ async def _prepare_in_context_messages_no_persist_async( new_in_context_messages = create_approval_response_message_from_input( agent_state=agent_state, input_message=input_messages[0], run_id=run_id ) + if len(input_messages) > 1: + follow_up_messages = await create_input_messages( + input_messages=input_messages[1:], agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor + ) + new_in_context_messages.extend(follow_up_messages) else: # User is trying to send a regular message if current_in_context_messages and current_in_context_messages[-1].role == "approval": diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 58e1fd5e..09b35fa5 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -114,9 +114,17 @@ class LettaAgentV3(LettaAgentV2): in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async( input_messages, self.agent_state, self.message_manager, self.actor, run_id ) + follow_up_messages = [] + if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval": + follow_up_messages = input_messages_to_persist[1:] + input_messages_to_persist = [input_messages_to_persist[0]] + in_context_messages = in_context_messages + input_messages_to_persist response_letta_messages = [] for i in range(max_steps): + if i == 1 and follow_up_messages: + input_messages_to_persist = follow_up_messages + follow_up_messages = [] response = self._step( messages=in_context_messages + self.response_messages, input_messages_to_persist=input_messages_to_persist, @@ -237,8 +245,16 @@ class LettaAgentV3(LettaAgentV2): in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async( input_messages, self.agent_state, self.message_manager, self.actor, run_id ) + follow_up_messages = [] + if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval": + follow_up_messages = input_messages_to_persist[1:] + input_messages_to_persist = [input_messages_to_persist[0]] + in_context_messages = in_context_messages + input_messages_to_persist for i in range(max_steps): + if i == 1 and follow_up_messages: + input_messages_to_persist = follow_up_messages + follow_up_messages = [] response = self._step( messages=in_context_messages + self.response_messages, input_messages_to_persist=input_messages_to_persist, diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 9a007f61..0f02b9ef 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -585,6 +585,51 @@ def test_approve_and_follow_up_with_error( assert messages[2].message_type == "tool_return_message" +def test_approve_with_user_message( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + }, + { + "type": "message", + "role": "user", + "content": "The secret code should not contain any special characters.", + }, + ], + ) + + response = client.agents.messages.stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"] + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" + + # ------------------------------ # Deny Test Cases # ------------------------------ @@ -800,6 +845,51 @@ def test_deny_and_follow_up_with_error( assert messages[-1].message_type == "usage_statistics" +def test_deny_with_user_message( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "approval", + "approve": False, + "tool_call_id": tool_call_id, + }, + ], + }, + { + "type": "message", + "role": "user", + "content": f"Actually, you don't need to call the tool, the secret code is {SECRET_CODE}", + }, + ], + ) + + response = client.agents.messages.stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" + + # -------------------------------- # Client-Side Execution Test Cases # -------------------------------- @@ -1020,6 +1110,52 @@ def test_client_side_tool_call_and_follow_up_with_error( assert messages[-1].message_type == "usage_statistics" +def test_client_side_tool_call_with_user_message( + client: Letta, + agent: AgentState, +) -> None: + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + client.agents.messages.create( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "tool", + "tool_call_id": tool_call_id, + "tool_return": SECRET_CODE, + "status": "success", + }, + ], + }, + { + "type": "message", + "role": "user", + "content": "The secret code should not contain any special characters.", + }, + ], + ) + + response = client.agents.messages.stream( + agent_id=agent.id, + messages=USER_MESSAGE_FOLLOW_UP, + stream_tokens=True, + ) + + messages = accumulate_chunks(response) + + assert messages is not None + assert len(messages) > 2 + assert messages[-2].message_type == "stop_reason" + assert messages[-1].message_type == "usage_statistics" + + def test_parallel_tool_calling( client: Letta, agent: AgentState,