feat: allow follow up user message for approvals LET-6272 (#6392)
* feat: allow follow up user message for approvals * add tests
This commit is contained in:
@@ -144,7 +144,7 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
|
||||
# Check for approval-related message validation
|
||||
if len(input_messages) == 1 and input_messages[0].type == "approval":
|
||||
if input_messages[0].type == "approval":
|
||||
# User is trying to send an approval response
|
||||
if current_in_context_messages and current_in_context_messages[-1].role != "approval":
|
||||
raise ValueError(
|
||||
@@ -155,6 +155,11 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
new_in_context_messages = create_approval_response_message_from_input(
|
||||
agent_state=agent_state, input_message=input_messages[0], run_id=run_id
|
||||
)
|
||||
if len(input_messages) > 1:
|
||||
follow_up_messages = await create_input_messages(
|
||||
input_messages=input_messages[1:], agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
|
||||
)
|
||||
new_in_context_messages.extend(follow_up_messages)
|
||||
else:
|
||||
# User is trying to send a regular message
|
||||
if current_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
|
||||
@@ -114,9 +114,17 @@ class LettaAgentV3(LettaAgentV2):
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
follow_up_messages = []
|
||||
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
|
||||
follow_up_messages = input_messages_to_persist[1:]
|
||||
input_messages_to_persist = [input_messages_to_persist[0]]
|
||||
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
response_letta_messages = []
|
||||
for i in range(max_steps):
|
||||
if i == 1 and follow_up_messages:
|
||||
input_messages_to_persist = follow_up_messages
|
||||
follow_up_messages = []
|
||||
response = self._step(
|
||||
messages=in_context_messages + self.response_messages,
|
||||
input_messages_to_persist=input_messages_to_persist,
|
||||
@@ -237,8 +245,16 @@ class LettaAgentV3(LettaAgentV2):
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
follow_up_messages = []
|
||||
if len(input_messages_to_persist) > 1 and input_messages_to_persist[0].role == "approval":
|
||||
follow_up_messages = input_messages_to_persist[1:]
|
||||
input_messages_to_persist = [input_messages_to_persist[0]]
|
||||
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
for i in range(max_steps):
|
||||
if i == 1 and follow_up_messages:
|
||||
input_messages_to_persist = follow_up_messages
|
||||
follow_up_messages = []
|
||||
response = self._step(
|
||||
messages=in_context_messages + self.response_messages,
|
||||
input_messages_to_persist=input_messages_to_persist,
|
||||
|
||||
@@ -585,6 +585,51 @@ def test_approve_and_follow_up_with_error(
|
||||
assert messages[2].message_type == "tool_return_message"
|
||||
|
||||
|
||||
def test_approve_with_user_message(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
||||
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
{
|
||||
"type": "approval",
|
||||
"approvals": [
|
||||
{
|
||||
"type": "approval",
|
||||
"approve": True,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": "The secret code should not contain any special characters.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
response = client.agents.messages.stream(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_FOLLOW_UP,
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert messages[0].message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
|
||||
assert messages[-2].message_type == "stop_reason"
|
||||
assert messages[-1].message_type == "usage_statistics"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Deny Test Cases
|
||||
# ------------------------------
|
||||
@@ -800,6 +845,51 @@ def test_deny_and_follow_up_with_error(
|
||||
assert messages[-1].message_type == "usage_statistics"
|
||||
|
||||
|
||||
def test_deny_with_user_message(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
||||
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
{
|
||||
"type": "approval",
|
||||
"approvals": [
|
||||
{
|
||||
"type": "approval",
|
||||
"approve": False,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": f"Actually, you don't need to call the tool, the secret code is {SECRET_CODE}",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
response = client.agents.messages.stream(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_FOLLOW_UP,
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) > 2
|
||||
assert messages[-2].message_type == "stop_reason"
|
||||
assert messages[-1].message_type == "usage_statistics"
|
||||
|
||||
|
||||
# --------------------------------
|
||||
# Client-Side Execution Test Cases
|
||||
# --------------------------------
|
||||
@@ -1020,6 +1110,52 @@ def test_client_side_tool_call_and_follow_up_with_error(
|
||||
assert messages[-1].message_type == "usage_statistics"
|
||||
|
||||
|
||||
def test_client_side_tool_call_with_user_message(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
||||
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
{
|
||||
"type": "approval",
|
||||
"approvals": [
|
||||
{
|
||||
"type": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_return": SECRET_CODE,
|
||||
"status": "success",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": "The secret code should not contain any special characters.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
response = client.agents.messages.stream(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_FOLLOW_UP,
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
assert messages is not None
|
||||
assert len(messages) > 2
|
||||
assert messages[-2].message_type == "stop_reason"
|
||||
assert messages[-1].message_type == "usage_statistics"
|
||||
|
||||
|
||||
def test_parallel_tool_calling(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
|
||||
Reference in New Issue
Block a user