feat: allow follow up user message for approvals LET-6272 (#6392)

* feat: allow follow up user message for approvals

* add tests
This commit is contained in:
cthomas
2025-11-25 17:02:57 -08:00
committed by Caren Thomas
parent 0653970533
commit db534836e4
3 changed files with 158 additions and 1 deletions

View File

@@ -144,7 +144,7 @@ async def _prepare_in_context_messages_no_persist_async(
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
# Check for approval-related message validation
if len(input_messages) == 1 and input_messages[0].type == "approval":
if input_messages[0].type == "approval":
# User is trying to send an approval response
if current_in_context_messages and current_in_context_messages[-1].role != "approval":
raise ValueError(
@@ -155,6 +155,11 @@ async def _prepare_in_context_messages_no_persist_async(
new_in_context_messages = create_approval_response_message_from_input(
agent_state=agent_state, input_message=input_messages[0], run_id=run_id
)
if len(input_messages) > 1:
follow_up_messages = await create_input_messages(
input_messages=input_messages[1:], agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
)
new_in_context_messages.extend(follow_up_messages)
else:
# User is trying to send a regular message
if current_in_context_messages and current_in_context_messages[-1].role == "approval":

View File

@@ -114,9 +114,17 @@ class LettaAgentV3(LettaAgentV2):
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
follow_up_messages = []
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
follow_up_messages = input_messages_to_persist[1:]
input_messages_to_persist = [input_messages_to_persist[0]]
in_context_messages = in_context_messages + input_messages_to_persist
response_letta_messages = []
for i in range(max_steps):
if i == 1 and follow_up_messages:
input_messages_to_persist = follow_up_messages
follow_up_messages = []
response = self._step(
messages=in_context_messages + self.response_messages,
input_messages_to_persist=input_messages_to_persist,
@@ -237,8 +245,16 @@ class LettaAgentV3(LettaAgentV2):
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
follow_up_messages = []
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
follow_up_messages = input_messages_to_persist[1:]
input_messages_to_persist = [input_messages_to_persist[0]]
in_context_messages = in_context_messages + input_messages_to_persist
for i in range(max_steps):
if i == 1 and follow_up_messages:
input_messages_to_persist = follow_up_messages
follow_up_messages = []
response = self._step(
messages=in_context_messages + self.response_messages,
input_messages_to_persist=input_messages_to_persist,

View File

@@ -585,6 +585,51 @@ def test_approve_and_follow_up_with_error(
assert messages[2].message_type == "tool_return_message"
def test_approve_with_user_message(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
client.agents.messages.create(
agent_id=agent.id,
messages=[
{
"type": "approval",
"approvals": [
{
"type": "approval",
"approve": True,
"tool_call_id": tool_call_id,
},
],
},
{
"type": "message",
"role": "user",
"content": "The secret code should not contain any special characters.",
},
],
)
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
# ------------------------------
# Deny Test Cases
# ------------------------------
@@ -800,6 +845,51 @@ def test_deny_and_follow_up_with_error(
assert messages[-1].message_type == "usage_statistics"
def test_deny_with_user_message(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
client.agents.messages.create(
agent_id=agent.id,
messages=[
{
"type": "approval",
"approvals": [
{
"type": "approval",
"approve": False,
"tool_call_id": tool_call_id,
},
],
},
{
"type": "message",
"role": "user",
"content": f"Actually, you don't need to call the tool, the secret code is {SECRET_CODE}",
},
],
)
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) > 2
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
# --------------------------------
# Client-Side Execution Test Cases
# --------------------------------
@@ -1020,6 +1110,52 @@ def test_client_side_tool_call_and_follow_up_with_error(
assert messages[-1].message_type == "usage_statistics"
def test_client_side_tool_call_with_user_message(
client: Letta,
agent: AgentState,
) -> None:
response = client.agents.messages.create(
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
tool_call_id = response.messages[-1].tool_call.tool_call_id
client.agents.messages.create(
agent_id=agent.id,
messages=[
{
"type": "approval",
"approvals": [
{
"type": "tool",
"tool_call_id": tool_call_id,
"tool_return": SECRET_CODE,
"status": "success",
},
],
},
{
"type": "message",
"role": "user",
"content": "The secret code should not contain any special characters.",
},
],
)
response = client.agents.messages.stream(
agent_id=agent.id,
messages=USER_MESSAGE_FOLLOW_UP,
stream_tokens=True,
)
messages = accumulate_chunks(response)
assert messages is not None
assert len(messages) > 2
assert messages[-2].message_type == "stop_reason"
assert messages[-1].message_type == "usage_statistics"
def test_parallel_tool_calling(
client: Letta,
agent: AgentState,