feat: handle message persistence for approvals flows (#4338)

This commit is contained in:
cthomas
2025-09-01 14:10:02 -07:00
committed by GitHub
parent 7c88470705
commit 3f87fc34f2
7 changed files with 189 additions and 28 deletions

View File

@@ -0,0 +1,35 @@
"""add approval fields to message model
Revision ID: f3bf00ef6118
Revises: 54c76f7cabca
Create Date: 2025-09-01 11:26:42.548009
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "f3bf00ef6118"
down_revision: Union[str, None] = "54c76f7cabca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the approval-tracking columns to the ``messages`` table.

    Adds three nullable columns used by the tool-approval flow:
    ``approval_request_id`` (links a response to its request),
    ``approve`` (tri-state: approved / denied / not an approval message),
    and ``denial_reason`` (free-text reason when denied).
    """
    new_columns = (
        sa.Column("approval_request_id", sa.String(), nullable=True),
        sa.Column("approve", sa.Boolean(), nullable=True),
        sa.Column("denial_reason", sa.String(), nullable=True),
    )
    for column in new_columns:
        op.add_column("messages", column)
def downgrade() -> None:
    """Drop the approval-tracking columns added by this revision.

    Columns are dropped in reverse order of their creation in ``upgrade``.
    """
    for column_name in ("denial_reason", "approve", "approval_request_id"):
        op.drop_column("messages", column_name)

View File

@@ -13,7 +13,7 @@ from letta.schemas.message import Message, MessageCreate, MessageCreateBase
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_input_messages
from letta.server.rest_api.utils import create_approval_response_message_from_input, create_input_messages
from letta.services.message_manager import MessageManager
logger = get_logger(__name__)
@@ -36,6 +36,8 @@ def _create_letta_response(
response_messages = Message.to_letta_messages_from_list(
messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
# Filter approval response messages
response_messages = [m for m in response_messages if m.message_type != "approval_response_message"]
# Apply message type filtering if specified
if include_return_message_types is not None:
@@ -161,7 +163,7 @@ async def _prepare_in_context_messages_no_persist_async(
f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' "
f"but received '{input_messages[0].approval_request_id}'."
)
new_in_context_messages = []
new_in_context_messages = create_approval_response_message_from_input(agent_state=agent_state, input_message=input_messages[0])
else:
# User is trying to send a regular message
if current_in_context_messages[-1].role == "approval":

View File

@@ -218,6 +218,7 @@ class LettaAgent(BaseAgent):
input_messages, agent_state, self.message_manager, self.actor
)
initial_messages = new_in_context_messages
in_context_messages = current_in_context_messages
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider_type=agent_state.llm_config.model_endpoint_type,
@@ -233,8 +234,8 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for i in range(max_steps):
if not new_in_context_messages and current_in_context_messages[-1].role == "approval":
approval_request_message = current_in_context_messages[-1]
if in_context_messages[-1].role == "approval":
approval_request_message = in_context_messages[-1]
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
approval_request_message.tool_calls[0],
@@ -244,18 +245,19 @@ class LettaAgent(BaseAgent):
usage,
reasoning_content=approval_request_message.content,
step_id=approval_request_message.step_id,
initial_messages=[],
initial_messages=initial_messages,
is_final_step=(i == max_steps - 1),
step_metrics=step_metrics,
run_id=self.current_run_id,
is_approval=input_messages[0].approve,
is_denial=not input_messages[0].approve,
is_denial=input_messages[0].approve == False,
denial_reason=input_messages[0].reason,
)
new_message_idx = 0
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
in_context_messages = current_in_context_messages + new_in_context_messages
# stream step
# TODO: improve TTFT
@@ -414,6 +416,7 @@ class LettaAgent(BaseAgent):
letta_messages = Message.to_letta_messages_from_list(
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
letta_messages = [m for m in letta_messages if m.message_type != "approval_response_message"]
for message in letta_messages:
if include_return_message_types is None or message.message_type in include_return_message_types:
@@ -557,6 +560,7 @@ class LettaAgent(BaseAgent):
input_messages, agent_state, self.message_manager, self.actor
)
initial_messages = new_in_context_messages
in_context_messages = current_in_context_messages
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider_type=agent_state.llm_config.model_endpoint_type,
@@ -572,8 +576,8 @@ class LettaAgent(BaseAgent):
job_update_metadata = None
usage = LettaUsageStatistics()
for i in range(max_steps):
if not new_in_context_messages and current_in_context_messages[-1].role == "approval":
approval_request_message = current_in_context_messages[-1]
if in_context_messages[-1].role == "approval":
approval_request_message = in_context_messages[-1]
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
approval_request_message.tool_calls[0],
@@ -583,18 +587,19 @@ class LettaAgent(BaseAgent):
usage,
reasoning_content=approval_request_message.content,
step_id=approval_request_message.step_id,
initial_messages=[],
initial_messages=initial_messages,
is_final_step=(i == max_steps - 1),
step_metrics=step_metrics,
run_id=run_id or self.current_run_id,
is_approval=input_messages[0].approve,
is_denial=not input_messages[0].approve,
is_denial=input_messages[0].approve == False,
denial_reason=input_messages[0].reason,
)
new_message_idx = 0
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
in_context_messages = current_in_context_messages + new_in_context_messages
else:
# If dry run, build request data and return it without making LLM call
if dry_run:
@@ -897,6 +902,7 @@ class LettaAgent(BaseAgent):
input_messages, agent_state, self.message_manager, self.actor
)
initial_messages = new_in_context_messages
in_context_messages = current_in_context_messages
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
@@ -913,8 +919,8 @@ class LettaAgent(BaseAgent):
request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
for i in range(max_steps):
if not new_in_context_messages and current_in_context_messages[-1].role == "approval":
approval_request_message = current_in_context_messages[-1]
if in_context_messages[-1].role == "approval":
approval_request_message = in_context_messages[-1]
step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor)
persisted_messages, should_continue, stop_reason = await self._handle_ai_response(
approval_request_message.tool_calls[0],
@@ -924,18 +930,19 @@ class LettaAgent(BaseAgent):
usage,
reasoning_content=approval_request_message.content,
step_id=approval_request_message.step_id,
initial_messages=[],
initial_messages=new_in_context_messages,
is_final_step=(i == max_steps - 1),
step_metrics=step_metrics,
run_id=self.current_run_id,
is_approval=input_messages[0].approve,
is_denial=not input_messages[0].approve,
is_denial=input_messages[0].approve == False,
denial_reason=input_messages[0].reason,
)
new_message_idx = 0
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
new_message_idx = len(initial_messages) if initial_messages else 0
self.response_messages.extend(persisted_messages[new_message_idx:])
new_in_context_messages.extend(persisted_messages[new_message_idx:])
initial_messages = None
in_context_messages = current_in_context_messages + new_in_context_messages
# yields tool response as this is handled from Letta and not the response from the LLM provider
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
@@ -1651,8 +1658,8 @@ class LettaAgent(BaseAgent):
step_id=step_id,
is_approval_response=True,
)
persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor)
return persisted_messages, continue_stepping, stop_reason
# 1. Parse and validate the tool-call envelope

View File

@@ -50,6 +50,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
is_err: Mapped[Optional[bool]] = mapped_column(
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
)
approval_request_id: Mapped[Optional[str]] = mapped_column(
nullable=True,
doc="The id of the approval request if this message is associated with a tool call request.",
)
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
# Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column(

View File

@@ -21,6 +21,7 @@ from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
ApprovalRequestMessage,
ApprovalResponseMessage,
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
@@ -199,6 +200,11 @@ class Message(BaseMessage):
is_err: Optional[bool] = Field(
default=None, description="Whether this message is part of an error step. Used only for debugging purposes."
)
approval_request_id: Optional[str] = Field(
default=None, description="The id of the approval request if this message is associated with a tool call request."
)
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -301,8 +307,18 @@ class Message(BaseMessage):
if self.tool_calls is not None:
tool_calls = self._convert_tool_call_messages()
assert len(tool_calls) == 1
approval_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
messages.append(approval_message)
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
messages.append(approval_request_message)
else:
approval_response_message = ApprovalResponseMessage(
id=self.id,
date=self.created_at,
otid=self.otid,
approve=self.approve,
approval_request_id=self.approval_request_id,
reason=self.denial_reason,
)
messages.append(approval_response_message)
else:
raise ValueError(f"Unknown role: {self.role}")
@@ -732,6 +748,8 @@ class Message(BaseMessage):
use_developer_message: bool = False,
) -> dict | None:
"""Go from Message class to ChatCompletion message object"""
if self.role == "approval" and self.tool_calls is None:
return None
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
# If we only have one content part and it's text, treat it as COT

View File

@@ -25,10 +25,11 @@ from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import tracer
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.message import ApprovalCreate, Message, MessageCreate, ToolReturn
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -176,6 +177,19 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti
return messages
def create_approval_response_message_from_input(agent_state: AgentState, input_message: ApprovalCreate) -> List[Message]:
    """Build the persistable approval-response ``Message`` for an approval input.

    Args:
        agent_state: The agent the approval response belongs to; supplies the
            agent id and the model name recorded on the message.
        input_message: The incoming approval decision, carrying the id of the
            approval request it answers, the approve/deny flag, and an
            optional denial reason.

    Returns:
        A single-element list, matching the shape of the other input-message
        creation helpers so callers can extend message lists uniformly.
    """
    approval_response = Message(
        role=MessageRole.approval,
        agent_id=agent_state.id,
        model=agent_state.llm_config.model,
        approval_request_id=input_message.approval_request_id,
        approve=input_message.approve,
        denial_reason=input_message.reason,
    )
    return [approval_response]
def create_approval_request_message_from_llm_response(
agent_id: str,
model: str,

View File

@@ -214,6 +214,45 @@ def test_approve_tool_call_request(
assert response.messages[2].message_type == "assistant_message"
def test_approve_cursor_fetch(
    client: Letta,
    agent: AgentState,
) -> None:
    """Approval-flow messages are persisted and retrievable via cursor paging.

    Triggers a tool call that requires approval, approves it, and verifies
    that listing messages ``after`` a cursor returns exactly the newly
    persisted messages in order — including the approval_response_message,
    which must now be stored (the subject of this commit).
    """
    # Anchor a cursor at the most recent pre-existing message so subsequent
    # list calls return only messages created by this test.
    last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # NOTE(review): assumes the approval request's id is exposed as the first
    # response message's id — confirm against the response filtering logic.
    approval_request_id = response.messages[0].id
    # The first turn persists: user input, reasoning, and the approval request.
    messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
    assert len(messages) == 3
    assert messages[0].message_type == "user_message"
    assert messages[1].message_type == "reasoning_message"
    assert messages[2].message_type == "approval_request_message"
    assert messages[2].id == approval_request_id
    # Advance the cursor to the approval request before answering it.
    last_message_cursor = approval_request_id
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=True,
                approval_request_id=approval_request_id,
            ),
        ],
    )
    # Approving persists the response, the tool's successful return, a
    # heartbeat user message, and the follow-up reasoning + assistant turn.
    messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
    assert len(messages) == 5
    assert messages[0].message_type == "approval_response_message"
    assert messages[1].message_type == "tool_return_message"
    assert messages[1].status == "success"
    assert messages[2].message_type == "user_message"  # heartbeat
    assert messages[3].message_type == "reasoning_message"
    assert messages[4].message_type == "assistant_message"
def test_deny_tool_call_request(
client: Letta,
agent: AgentState,
@@ -244,3 +283,43 @@ def test_deny_tool_call_request(
assert response.messages[1].message_type == "reasoning_message"
assert response.messages[2].message_type == "assistant_message"
assert SECRET_CODE in response.messages[2].content
def test_deny_cursor_fetch(
    client: Letta,
    agent: AgentState,
) -> None:
    """Denial-flow messages are persisted and retrievable via cursor paging.

    Mirrors ``test_approve_cursor_fetch`` but denies the tool call: the
    persisted tool_return_message must carry an ``error`` status, while the
    approval_response_message (with the denial reason) is still stored.
    """
    # Anchor a cursor at the most recent pre-existing message so subsequent
    # list calls return only messages created by this test.
    last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # NOTE(review): assumes the approval request's id is exposed as the first
    # response message's id — confirm against the response filtering logic.
    approval_request_id = response.messages[0].id
    # The first turn persists: user input, reasoning, and the approval request.
    messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
    assert len(messages) == 3
    assert messages[0].message_type == "user_message"
    assert messages[1].message_type == "reasoning_message"
    assert messages[2].message_type == "approval_request_message"
    assert messages[2].id == approval_request_id
    # Advance the cursor to the approval request before answering it.
    last_message_cursor = approval_request_id
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[
            ApprovalCreate(
                approve=False,
                approval_request_id=approval_request_id,
                reason=f"You don't need to call the tool, the secret code is {SECRET_CODE}",
            ),
        ],
    )
    # Denial persists the response, an errored tool return, a heartbeat user
    # message, and the follow-up reasoning + assistant turn.
    messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor)
    assert len(messages) == 5
    assert messages[0].message_type == "approval_response_message"
    assert messages[1].message_type == "tool_return_message"
    assert messages[1].status == "error"
    assert messages[2].message_type == "user_message"  # heartbeat
    assert messages[3].message_type == "reasoning_message"
    assert messages[4].message_type == "assistant_message"