feat: add approvals persistence to message orm (#5309)

* feat: add approvals persistence to message orm

* fix imports in alembic migration

* missing import
This commit is contained in:
cthomas
2025-10-09 19:10:30 -07:00
committed by Caren Thomas
parent 8e54f40bde
commit 1c80e1c11f
5 changed files with 137 additions and 18 deletions

View File

@@ -0,0 +1,41 @@
"""add approvals field to messages
Revision ID: 066857381578
Revises: c734cfc0d595
Create Date: 2025-10-09 17:56:07.333221
"""
from typing import Sequence, Union
import sqlalchemy as sa
import letta.orm
from alembic import op
from letta.settings import settings
# revision identifiers, used by Alembic.
revision: str = "066857381578"
down_revision: Union[str, None] = "c734cfc0d595"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Skip this migration for SQLite
if not settings.letta_pg_uri_no_default:
return
### commands auto generated by Alembic - please adjust! ###
op.add_column("messages", sa.Column("approvals", letta.orm.custom_columns.ApprovalsColumn(), nullable=True))
### end Alembic commands ###
def downgrade() -> None:
# Skip this migration for SQLite
if not settings.letta_pg_uri_no_default:
return
### commands auto generated by Alembic - please adjust! ###
op.drop_column("messages", "approvals")
### end Alembic commands ###

View File

@@ -8,6 +8,7 @@ from sqlalchemy import Dialect
from letta.functions.mcp_client.types import StdioServerConfig
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType, ToolRuleType
from letta.schemas.letta_message import ApprovalReturn
from letta.schemas.letta_message_content import (
ImageContent,
ImageSourceType,
@@ -222,6 +223,49 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
return tool_returns
# --------------------------
# Approvals Serialization
# --------------------------
def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]:
"""Convert a list of ToolReturn objects into JSON-serializable format."""
if not approvals:
return []
serialized_approvals = []
for approval in approvals:
if isinstance(approval, ToolReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
elif isinstance(approval, ApprovalReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
elif isinstance(approval, dict):
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
else:
raise TypeError(f"Unexpected approval type: {type(approval)}")
return serialized_tool_returns
def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalReturn, ToolReturn]]:
"""Convert a JSON list back into ApprovalReturn and ToolReturn objects."""
if not data:
return []
approvals = []
for item in data:
if "approve" in item:
approval_return = ApprovalReturn(**item)
approvals.append(approval_return)
elif "status" in item:
tool_return = ToolReturn(**item)
approvals.append(tool_return)
else:
raise TypeError(f"Unexpected approval type: {type(item)}")
return approvals
# ----------------------------
# MessageContent Serialization
# ----------------------------

View File

@@ -3,6 +3,7 @@ from sqlalchemy.types import BINARY, TypeDecorator
from letta.helpers.converters import (
deserialize_agent_step_state,
deserialize_approvals,
deserialize_batch_request_result,
deserialize_create_batch_response,
deserialize_embedding_config,
@@ -16,6 +17,7 @@ from letta.helpers.converters import (
deserialize_tool_rules,
deserialize_vector,
serialize_agent_step_state,
serialize_approvals,
serialize_batch_request_result,
serialize_create_batch_response,
serialize_embedding_config,
@@ -96,6 +98,19 @@ class ToolReturnColumn(TypeDecorator):
return deserialize_tool_returns(value)
class ApprovalsColumn(TypeDecorator):
"""Custom SQLAlchemy column type for storing the approval responses of a tool call request as JSON."""
impl = JSON
cache_ok = True
def process_bind_param(self, value, dialect):
return serialize_approvals(value)
def process_result_value(self, value, dialect):
return deserialize_approvals(value)
class MessageContentColumn(TypeDecorator):
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""

View File

@@ -4,10 +4,11 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
from letta.orm.custom_columns import ApprovalsColumn, MessageContentColumn, ToolCallColumn, ToolReturnColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ApprovalReturn
from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent
from letta.schemas.message import Message as PydanticMessage, ToolReturn
from letta.settings import DatabaseChoice, settings
@@ -63,6 +64,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
)
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
approvals: Mapped[Optional[List[ApprovalReturn | ToolReturn]]] = mapped_column(
ApprovalsColumn, nullable=True, doc="Approval responses for tool call requests"
)
# Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column(

View File

@@ -224,6 +224,7 @@ class Message(BaseMessage):
)
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -341,23 +342,37 @@ class Message(BaseMessage):
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
messages.append(approval_request_message)
else:
approval_response_message = ApprovalResponseMessage(
id=self.id,
date=self.created_at,
otid=self.otid,
approve=self.approve,
approval_request_id=self.approval_request_id,
reason=self.denial_reason,
approvals=[
# TODO: temporary workaround to populate from legacy fields
ApprovalReturn(
tool_call_id=self.approval_request_id,
approve=self.approve,
reason=self.denial_reason,
)
],
run_id=self.run_id,
)
if self.approvals:
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0]
approval_response_message = ApprovalResponseMessage(
id=self.id,
date=self.created_at,
otid=self.otid,
approvals=self.approvals,
run_id=self.run_id,
# TODO: temporary populate these fields for backwards compatibility
approve=first_approval.approve,
approval_request_id=first_approval.tool_call_id,
reason=first_approval.reason,
)
else:
approval_response_message = ApprovalResponseMessage(
id=self.id,
date=self.created_at,
otid=self.otid,
approve=self.approve,
approval_request_id=self.approval_request_id,
reason=self.denial_reason,
approvals=[
# TODO: temporary workaround to populate from legacy fields
ApprovalReturn(
tool_call_id=self.approval_request_id,
approve=self.approve,
reason=self.denial_reason,
)
],
run_id=self.run_id,
)
messages.append(approval_response_message)
else:
raise ValueError(f"Unknown role: {self.role}")