feat: handle message persistence for approvals flows (#4338)
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
"""add approval fields to message model
|
||||
|
||||
Revision ID: f3bf00ef6118
|
||||
Revises: 54c76f7cabca
|
||||
Create Date: 2025-09-01 11:26:42.548009
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f3bf00ef6118"
|
||||
down_revision: Union[str, None] = "54c76f7cabca"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("messages", sa.Column("approval_request_id", sa.String(), nullable=True))
|
||||
op.add_column("messages", sa.Column("approve", sa.Boolean(), nullable=True))
|
||||
op.add_column("messages", sa.Column("denial_reason", sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("messages", "denial_reason")
|
||||
op.drop_column("messages", "approve")
|
||||
op.drop_column("messages", "approval_request_id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -13,7 +13,7 @@ from letta.schemas.message import Message, MessageCreate, MessageCreateBase
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_input_messages
|
||||
from letta.server.rest_api.utils import create_approval_response_message_from_input, create_input_messages
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -36,6 +36,8 @@ def _create_letta_response(
|
||||
response_messages = Message.to_letta_messages_from_list(
|
||||
messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
)
|
||||
# Filter approval response messages
|
||||
response_messages = [m for m in response_messages if m.message_type != "approval_response_message"]
|
||||
|
||||
# Apply message type filtering if specified
|
||||
if include_return_message_types is not None:
|
||||
@@ -161,7 +163,7 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' "
|
||||
f"but received '{input_messages[0].approval_request_id}'."
|
||||
)
|
||||
new_in_context_messages = []
|
||||
new_in_context_messages = create_approval_response_message_from_input(agent_state=agent_state, input_message=input_messages[0])
|
||||
else:
|
||||
# User is trying to send a regular message
|
||||
if current_in_context_messages[-1].role == "approval":
|
||||
|
||||
@@ -218,6 +218,7 @@ class LettaAgent(BaseAgent):
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
initial_messages = new_in_context_messages
|
||||
in_context_messages = current_in_context_messages
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
@@ -233,8 +234,8 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
for i in range(max_steps):
|
||||
if not new_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
approval_request_message = current_in_context_messages[-1]
|
||||
if in_context_messages[-1].role == "approval":
|
||||
approval_request_message = in_context_messages[-1]
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
approval_request_message.tool_calls[0],
|
||||
@@ -244,18 +245,19 @@ class LettaAgent(BaseAgent):
|
||||
usage,
|
||||
reasoning_content=approval_request_message.content,
|
||||
step_id=approval_request_message.step_id,
|
||||
initial_messages=[],
|
||||
initial_messages=initial_messages,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
step_metrics=step_metrics,
|
||||
run_id=self.current_run_id,
|
||||
is_approval=input_messages[0].approve,
|
||||
is_denial=not input_messages[0].approve,
|
||||
is_denial=input_messages[0].approve == False,
|
||||
denial_reason=input_messages[0].reason,
|
||||
)
|
||||
new_message_idx = 0
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
initial_messages = None
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
# stream step
|
||||
# TODO: improve TTFT
|
||||
@@ -414,6 +416,7 @@ class LettaAgent(BaseAgent):
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
)
|
||||
letta_messages = [m for m in letta_messages if m.message_type != "approval_response_message"]
|
||||
|
||||
for message in letta_messages:
|
||||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||||
@@ -557,6 +560,7 @@ class LettaAgent(BaseAgent):
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
initial_messages = new_in_context_messages
|
||||
in_context_messages = current_in_context_messages
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
@@ -572,8 +576,8 @@ class LettaAgent(BaseAgent):
|
||||
job_update_metadata = None
|
||||
usage = LettaUsageStatistics()
|
||||
for i in range(max_steps):
|
||||
if not new_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
approval_request_message = current_in_context_messages[-1]
|
||||
if in_context_messages[-1].role == "approval":
|
||||
approval_request_message = in_context_messages[-1]
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
approval_request_message.tool_calls[0],
|
||||
@@ -583,18 +587,19 @@ class LettaAgent(BaseAgent):
|
||||
usage,
|
||||
reasoning_content=approval_request_message.content,
|
||||
step_id=approval_request_message.step_id,
|
||||
initial_messages=[],
|
||||
initial_messages=initial_messages,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
step_metrics=step_metrics,
|
||||
run_id=run_id or self.current_run_id,
|
||||
is_approval=input_messages[0].approve,
|
||||
is_denial=not input_messages[0].approve,
|
||||
is_denial=input_messages[0].approve == False,
|
||||
denial_reason=input_messages[0].reason,
|
||||
)
|
||||
new_message_idx = 0
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
initial_messages = None
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
else:
|
||||
# If dry run, build request data and return it without making LLM call
|
||||
if dry_run:
|
||||
@@ -897,6 +902,7 @@ class LettaAgent(BaseAgent):
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
initial_messages = new_in_context_messages
|
||||
in_context_messages = current_in_context_messages
|
||||
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
@@ -913,8 +919,8 @@ class LettaAgent(BaseAgent):
|
||||
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
for i in range(max_steps):
|
||||
if not new_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
approval_request_message = current_in_context_messages[-1]
|
||||
if in_context_messages[-1].role == "approval":
|
||||
approval_request_message = in_context_messages[-1]
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
|
||||
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
|
||||
approval_request_message.tool_calls[0],
|
||||
@@ -924,18 +930,19 @@ class LettaAgent(BaseAgent):
|
||||
usage,
|
||||
reasoning_content=approval_request_message.content,
|
||||
step_id=approval_request_message.step_id,
|
||||
initial_messages=[],
|
||||
initial_messages=new_in_context_messages,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
step_metrics=step_metrics,
|
||||
run_id=self.current_run_id,
|
||||
is_approval=input_messages[0].approve,
|
||||
is_denial=not input_messages[0].approve,
|
||||
is_denial=input_messages[0].approve == False,
|
||||
denial_reason=input_messages[0].reason,
|
||||
)
|
||||
new_message_idx = 0
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
new_message_idx = len(initial_messages) if initial_messages else 0
|
||||
self.response_messages.extend(persisted_messages[new_message_idx:])
|
||||
new_in_context_messages.extend(persisted_messages[new_message_idx:])
|
||||
initial_messages = None
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||
@@ -1651,8 +1658,8 @@ class LettaAgent(BaseAgent):
|
||||
step_id=step_id,
|
||||
is_approval_response=True,
|
||||
)
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
# 1. Parse and validate the tool-call envelope
|
||||
|
||||
@@ -50,6 +50,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
is_err: Mapped[Optional[bool]] = mapped_column(
|
||||
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
approval_request_id: Mapped[Optional[str]] = mapped_column(
|
||||
nullable=True,
|
||||
doc="The id of the approval request if this message is associated with a tool call request.",
|
||||
)
|
||||
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
|
||||
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
|
||||
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
|
||||
@@ -21,6 +21,7 @@ from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
ApprovalRequestMessage,
|
||||
ApprovalResponseMessage,
|
||||
AssistantMessage,
|
||||
HiddenReasoningMessage,
|
||||
LettaMessage,
|
||||
@@ -199,6 +200,11 @@ class Message(BaseMessage):
|
||||
is_err: Optional[bool] = Field(
|
||||
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
|
||||
)
|
||||
approval_request_id: Optional[str] = Field(
|
||||
default=None, description="The id of the approval request if this message is associated with a tool call request."
|
||||
)
|
||||
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
|
||||
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@@ -301,8 +307,18 @@ class Message(BaseMessage):
|
||||
if self.tool_calls is not None:
|
||||
tool_calls = self._convert_tool_call_messages()
|
||||
assert len(tool_calls) == 1
|
||||
approval_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
|
||||
messages.append(approval_message)
|
||||
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
|
||||
messages.append(approval_request_message)
|
||||
else:
|
||||
approval_response_message = ApprovalResponseMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
otid=self.otid,
|
||||
approve=self.approve,
|
||||
approval_request_id=self.approval_request_id,
|
||||
reason=self.denial_reason,
|
||||
)
|
||||
messages.append(approval_response_message)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {self.role}")
|
||||
|
||||
@@ -732,6 +748,8 @@ class Message(BaseMessage):
|
||||
use_developer_message: bool = False,
|
||||
) -> dict | None:
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
if self.role == "approval" and self.tool_calls is None:
|
||||
return None
|
||||
|
||||
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
|
||||
# If we only have one content part and it's text, treat it as COT
|
||||
|
||||
@@ -25,10 +25,11 @@ from letta.log import get_logger
|
||||
from letta.otel.context import get_ctx_attributes
|
||||
from letta.otel.metric_registry import MetricRegistry
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.message import ApprovalCreate, Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -176,6 +177,19 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
|
||||
return messages
|
||||
|
||||
|
||||
def create_approval_response_message_from_input(agent_state: AgentState, input_message: ApprovalCreate) -> List[Message]:
    """Convert an ``ApprovalCreate`` input into a persistable approval-role message.

    The approval decision, the id of the request it answers, and an optional
    denial reason are copied onto a new ``Message`` tagged with the agent and
    its configured model. A single-element list is returned so callers can
    treat the result like any other batch of input messages.
    """
    approval_message = Message(
        role=MessageRole.approval,
        agent_id=agent_state.id,
        model=agent_state.llm_config.model,
        approval_request_id=input_message.approval_request_id,
        approve=input_message.approve,
        denial_reason=input_message.reason,
    )
    return [approval_message]
|
||||
|
||||
|
||||
def create_approval_request_message_from_llm_response(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
|
||||
@@ -214,6 +214,45 @@ def test_approve_tool_call_request(
|
||||
assert response.messages[2].message_type == "assistant_message"
|
||||
|
||||
|
||||
def test_approve_cursor_fetch(
    client: Letta,
    agent: AgentState,
) -> None:
    """Approving a tool call persists an approval_response_message that is
    visible when paging messages with an ``after`` cursor."""
    cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id

    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    approval_request_id = response.messages[0].id

    # The user turn, its reasoning, and the approval request all land after the cursor.
    fetched = client.agents.messages.list(agent_id=agent.id, after=cursor)
    assert [m.message_type for m in fetched] == [
        "user_message",
        "reasoning_message",
        "approval_request_message",
    ]
    assert fetched[-1].id == approval_request_id

    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,
                approval_request_id=approval_request_id,
            ),
        ],
    )

    # Paging from the approval request: the response, a successful tool
    # return, a heartbeat, then the follow-up assistant turn.
    fetched = client.agents.messages.list(agent_id=agent.id, after=approval_request_id)
    assert [m.message_type for m in fetched] == [
        "approval_response_message",
        "tool_return_message",
        "user_message",  # heartbeat
        "reasoning_message",
        "assistant_message",
    ]
    assert fetched[1].status == "success"
|
||||
|
||||
|
||||
def test_deny_tool_call_request(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
@@ -244,3 +283,43 @@ def test_deny_tool_call_request(
|
||||
assert response.messages[1].message_type == "reasoning_message"
|
||||
assert response.messages[2].message_type == "assistant_message"
|
||||
assert SECRET_CODE in response.messages[2].content
|
||||
|
||||
|
||||
def test_deny_cursor_fetch(
    client: Letta,
    agent: AgentState,
) -> None:
    """Denying a tool call persists an approval_response_message plus an
    errored tool return, both visible via cursor-based listing."""
    cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id

    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    approval_request_id = response.messages[0].id

    # The user turn, its reasoning, and the approval request all land after the cursor.
    fetched = client.agents.messages.list(agent_id=agent.id, after=cursor)
    assert [m.message_type for m in fetched] == [
        "user_message",
        "reasoning_message",
        "approval_request_message",
    ]
    assert fetched[-1].id == approval_request_id

    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,
                approval_request_id=approval_request_id,
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",
            ),
        ],
    )

    # A denial yields an errored tool return before the agent continues.
    fetched = client.agents.messages.list(agent_id=agent.id, after=approval_request_id)
    assert [m.message_type for m in fetched] == [
        "approval_response_message",
        "tool_return_message",
        "user_message",  # heartbeat
        "reasoning_message",
        "assistant_message",
    ]
    assert fetched[1].status == "error"
|
||||
|
||||
Reference in New Issue
Block a user