diff --git a/alembic/versions/f3bf00ef6118_add_approval_fields_to_message_model.py b/alembic/versions/f3bf00ef6118_add_approval_fields_to_message_model.py new file mode 100644 index 00000000..e7de5b3a --- /dev/null +++ b/alembic/versions/f3bf00ef6118_add_approval_fields_to_message_model.py @@ -0,0 +1,35 @@ +"""add approval fields to message model + +Revision ID: f3bf00ef6118 +Revises: 54c76f7cabca +Create Date: 2025-09-01 11:26:42.548009 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f3bf00ef6118" +down_revision: Union[str, None] = "54c76f7cabca" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("messages", sa.Column("approval_request_id", sa.String(), nullable=True)) + op.add_column("messages", sa.Column("approve", sa.Boolean(), nullable=True)) + op.add_column("messages", sa.Column("denial_reason", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("messages", "denial_reason") + op.drop_column("messages", "approve") + op.drop_column("messages", "approval_request_id") + # ### end Alembic commands ### diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 0d5da5fc..608199e5 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -13,7 +13,7 @@ from letta.schemas.message import Message, MessageCreate, MessageCreateBase from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User -from letta.server.rest_api.utils import create_input_messages +from letta.server.rest_api.utils import create_approval_response_message_from_input, create_input_messages from letta.services.message_manager import MessageManager logger = get_logger(__name__) @@ -36,6 +36,8 @@ def _create_letta_response( response_messages = Message.to_letta_messages_from_list( messages=filter_user_messages, use_assistant_message=use_assistant_message, reverse=False ) + # Filter approval response messages + response_messages = [m for m in response_messages if m.message_type != "approval_response_message"] # Apply message type filtering if specified if include_return_message_types is not None: @@ -161,7 +163,7 @@ async def _prepare_in_context_messages_no_persist_async( f"Invalid approval request ID. Expected '{current_in_context_messages[-1].id}' " f"but received '{input_messages[0].approval_request_id}'." 
) - new_in_context_messages = [] + new_in_context_messages = create_approval_response_message_from_input(agent_state=agent_state, input_message=input_messages[0]) else: # User is trying to send a regular message if current_in_context_messages[-1].role == "approval": diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 4362214c..49a50cb1 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -218,6 +218,7 @@ class LettaAgent(BaseAgent): input_messages, agent_state, self.message_manager, self.actor ) initial_messages = new_in_context_messages + in_context_messages = current_in_context_messages tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( provider_type=agent_state.llm_config.model_endpoint_type, @@ -233,8 +234,8 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): - if not new_in_context_messages and current_in_context_messages[-1].role == "approval": - approval_request_message = current_in_context_messages[-1] + if in_context_messages[-1].role == "approval": + approval_request_message = in_context_messages[-1] step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( approval_request_message.tool_calls[0], @@ -244,18 +245,19 @@ class LettaAgent(BaseAgent): usage, reasoning_content=approval_request_message.content, step_id=approval_request_message.step_id, - initial_messages=[], + initial_messages=initial_messages, is_final_step=(i == max_steps - 1), step_metrics=step_metrics, run_id=self.current_run_id, is_approval=input_messages[0].approve, - is_denial=not input_messages[0].approve, + is_denial=input_messages[0].approve == False, denial_reason=input_messages[0].reason, ) - new_message_idx = 0 - 
self.response_messages.extend(persisted_messages) - new_in_context_messages.extend(persisted_messages) + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) initial_messages = None + in_context_messages = current_in_context_messages + new_in_context_messages # stream step # TODO: improve TTFT @@ -414,6 +416,7 @@ class LettaAgent(BaseAgent): letta_messages = Message.to_letta_messages_from_list( filter_user_messages, use_assistant_message=use_assistant_message, reverse=False ) + letta_messages = [m for m in letta_messages if m.message_type != "approval_response_message"] for message in letta_messages: if include_return_message_types is None or message.message_type in include_return_message_types: @@ -557,6 +560,7 @@ class LettaAgent(BaseAgent): input_messages, agent_state, self.message_manager, self.actor ) initial_messages = new_in_context_messages + in_context_messages = current_in_context_messages tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( provider_type=agent_state.llm_config.model_endpoint_type, @@ -572,8 +576,8 @@ class LettaAgent(BaseAgent): job_update_metadata = None usage = LettaUsageStatistics() for i in range(max_steps): - if not new_in_context_messages and current_in_context_messages[-1].role == "approval": - approval_request_message = current_in_context_messages[-1] + if in_context_messages[-1].role == "approval": + approval_request_message = in_context_messages[-1] step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( approval_request_message.tool_calls[0], @@ -583,18 +587,19 @@ class LettaAgent(BaseAgent): usage, reasoning_content=approval_request_message.content, step_id=approval_request_message.step_id, 
- initial_messages=[], + initial_messages=initial_messages, is_final_step=(i == max_steps - 1), step_metrics=step_metrics, run_id=run_id or self.current_run_id, is_approval=input_messages[0].approve, - is_denial=not input_messages[0].approve, + is_denial=input_messages[0].approve == False, denial_reason=input_messages[0].reason, ) - new_message_idx = 0 - self.response_messages.extend(persisted_messages) - new_in_context_messages.extend(persisted_messages) + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) initial_messages = None + in_context_messages = current_in_context_messages + new_in_context_messages else: # If dry run, build request data and return it without making LLM call if dry_run: @@ -897,6 +902,7 @@ class LettaAgent(BaseAgent): input_messages, agent_state, self.message_manager, self.actor ) initial_messages = new_in_context_messages + in_context_messages = current_in_context_messages tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( @@ -913,8 +919,8 @@ class LettaAgent(BaseAgent): request_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) for i in range(max_steps): - if not new_in_context_messages and current_in_context_messages[-1].role == "approval": - approval_request_message = current_in_context_messages[-1] + if in_context_messages[-1].role == "approval": + approval_request_message = in_context_messages[-1] step_metrics = await self.step_manager.get_step_metrics_async(step_id=approval_request_message.step_id, actor=self.actor) persisted_messages, should_continue, stop_reason = await self._handle_ai_response( approval_request_message.tool_calls[0], @@ -924,18 +930,19 @@ class LettaAgent(BaseAgent): usage, reasoning_content=approval_request_message.content, 
step_id=approval_request_message.step_id, - initial_messages=[], + initial_messages=initial_messages, is_final_step=(i == max_steps - 1), step_metrics=step_metrics, run_id=self.current_run_id, is_approval=input_messages[0].approve, - is_denial=not input_messages[0].approve, + is_denial=input_messages[0].approve == False, denial_reason=input_messages[0].reason, ) - new_message_idx = 0 - self.response_messages.extend(persisted_messages) - new_in_context_messages.extend(persisted_messages) + new_message_idx = len(initial_messages) if initial_messages else 0 + self.response_messages.extend(persisted_messages[new_message_idx:]) + new_in_context_messages.extend(persisted_messages[new_message_idx:]) initial_messages = None + in_context_messages = current_in_context_messages + new_in_context_messages # yields tool response as this is handled from Letta and not the response from the LLM provider tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] @@ -1651,8 +1658,8 @@ class LettaAgent(BaseAgent): step_id=step_id, is_approval_response=True, ) - - persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor) + messages_to_persist = (initial_messages or []) + tool_call_messages + persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor) return persisted_messages, continue_stepping, stop_reason # 1. Parse and validate the tool-call envelope diff --git a/letta/orm/message.py b/letta/orm/message.py index c331c593..76b9a8c0 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -50,6 +50,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): is_err: Mapped[Optional[bool]] = mapped_column( nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes." 
) + approval_request_id: Mapped[Optional[str]] = mapped_column( + nullable=True, + doc="The id of the approval request if this message is associated with a tool call request.", + ) + approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.") + denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.") # Monotonically increasing sequence for efficient/correct listing sequence_id: Mapped[int] = mapped_column( diff --git a/letta/schemas/message.py b/letta/schemas/message.py index ee5430d6..adc68909 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -21,6 +21,7 @@ from letta.schemas.enums import MessageRole from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( ApprovalRequestMessage, + ApprovalResponseMessage, AssistantMessage, HiddenReasoningMessage, LettaMessage, @@ -199,6 +200,11 @@ class Message(BaseMessage): is_err: Optional[bool] = Field( default=None, description="Whether this message is part of an error step. Used only for debugging purposes." ) + approval_request_id: Optional[str] = Field( + default=None, description="The id of the approval request if this message is associated with a tool call request." 
+ ) + approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.") + denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") @@ -301,8 +307,18 @@ class Message(BaseMessage): if self.tool_calls is not None: tool_calls = self._convert_tool_call_messages() assert len(tool_calls) == 1 - approval_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"})) - messages.append(approval_message) + approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"})) + messages.append(approval_request_message) + else: + approval_response_message = ApprovalResponseMessage( + id=self.id, + date=self.created_at, + otid=self.otid, + approve=self.approve, + approval_request_id=self.approval_request_id, + reason=self.denial_reason, + ) + messages.append(approval_response_message) else: raise ValueError(f"Unknown role: {self.role}") @@ -732,6 +748,8 @@ class Message(BaseMessage): use_developer_message: bool = False, ) -> dict | None: """Go from Message class to ChatCompletion message object""" + if self.role == "approval" and self.tool_calls is None: + return None # TODO change to pydantic casting, eg `return SystemMessageModel(self)` # If we only have one content part and it's text, treat it as COT diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index c72b2513..5dbdc0de 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -25,10 +25,11 @@ from letta.log import get_logger from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry from letta.otel.tracing import tracer +from letta.schemas.agent import AgentState from 
letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message, MessageCreate, ToolReturn +from letta.schemas.message import ApprovalCreate, Message, MessageCreate, ToolReturn from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -176,6 +177,19 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti return messages +def create_approval_response_message_from_input(agent_state: AgentState, input_message: ApprovalCreate) -> List[Message]: + return [ + Message( + role=MessageRole.approval, + agent_id=agent_state.id, + model=agent_state.llm_config.model, + approval_request_id=input_message.approval_request_id, + approve=input_message.approve, + denial_reason=input_message.reason, + ) + ] + + def create_approval_request_message_from_llm_response( agent_id: str, model: str, diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index dc022ece..8e132a4b 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -214,6 +214,45 @@ def test_approve_tool_call_request( assert response.messages[2].message_type == "assistant_message" +def test_approve_cursor_fetch( + client: Letta, + agent: AgentState, +) -> None: + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + approval_request_id = response.messages[0].id + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + assert len(messages) == 3 + assert messages[0].message_type == "user_message" + assert messages[1].message_type == 
"reasoning_message" + assert messages[2].message_type == "approval_request_message" + assert messages[2].id == approval_request_id + + last_message_cursor = approval_request_id + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=True, + approval_request_id=approval_request_id, + ), + ], + ) + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + assert len(messages) == 5 + assert messages[0].message_type == "approval_response_message" + assert messages[1].message_type == "tool_return_message" + assert messages[1].status == "success" + assert messages[2].message_type == "user_message" # heartbeat + assert messages[3].message_type == "reasoning_message" + assert messages[4].message_type == "assistant_message" + + def test_deny_tool_call_request( client: Letta, agent: AgentState, @@ -244,3 +283,43 @@ def test_deny_tool_call_request( assert response.messages[1].message_type == "reasoning_message" assert response.messages[2].message_type == "assistant_message" assert SECRET_CODE in response.messages[2].content + + +def test_deny_cursor_fetch( + client: Letta, + agent: AgentState, +) -> None: + last_message_cursor = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + approval_request_id = response.messages[0].id + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + assert len(messages) == 3 + assert messages[0].message_type == "user_message" + assert messages[1].message_type == "reasoning_message" + assert messages[2].message_type == "approval_request_message" + assert messages[2].id == approval_request_id + + last_message_cursor = approval_request_id + client.agents.messages.create( + agent_id=agent.id, + messages=[ + ApprovalCreate( + approve=False, + approval_request_id=approval_request_id, + reason=f"You don't need to call the tool, the 
secret code is {SECRET_CODE}", + ), + ], + ) + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_cursor) + assert len(messages) == 5 + assert messages[0].message_type == "approval_response_message" + assert messages[1].message_type == "tool_return_message" + assert messages[1].status == "error" + assert messages[2].message_type == "user_message" # heartbeat + assert messages[3].message_type == "reasoning_message" + assert messages[4].message_type == "assistant_message"