feat: add approvals persistence to message orm (#5309)

* feat: add approvals persistence to message orm

* fix imports in alembic migration

* missing import
This commit is contained in:
cthomas
2025-10-09 19:10:30 -07:00
committed by Caren Thomas
parent 8e54f40bde
commit 1c80e1c11f
5 changed files with 137 additions and 18 deletions

View File

@@ -0,0 +1,41 @@
"""add approvals field to messages
Revision ID: 066857381578
Revises: c734cfc0d595
Create Date: 2025-10-09 17:56:07.333221
"""
from typing import Sequence, Union
import sqlalchemy as sa
import letta.orm
from alembic import op
from letta.settings import settings
# revision identifiers, used by Alembic.
revision: str = "066857381578"
down_revision: Union[str, None] = "c734cfc0d595"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Skip this migration for SQLite
if not settings.letta_pg_uri_no_default:
return
### commands auto generated by Alembic - please adjust! ###
op.add_column("messages", sa.Column("approvals", letta.orm.custom_columns.ApprovalsColumn(), nullable=True))
### end Alembic commands ###
def downgrade() -> None:
# Skip this migration for SQLite
if not settings.letta_pg_uri_no_default:
return
### commands auto generated by Alembic - please adjust! ###
op.drop_column("messages", "approvals")
### end Alembic commands ###

View File

@@ -8,6 +8,7 @@ from sqlalchemy import Dialect
from letta.functions.mcp_client.types import StdioServerConfig from letta.functions.mcp_client.types import StdioServerConfig
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType, ToolRuleType from letta.schemas.enums import ProviderType, ToolRuleType
from letta.schemas.letta_message import ApprovalReturn
from letta.schemas.letta_message_content import ( from letta.schemas.letta_message_content import (
ImageContent, ImageContent,
ImageSourceType, ImageSourceType,
@@ -222,6 +223,49 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
return tool_returns return tool_returns
# --------------------------
# Approvals Serialization
# --------------------------
def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]:
"""Convert a list of ToolReturn objects into JSON-serializable format."""
if not approvals:
return []
serialized_approvals = []
for approval in approvals:
if isinstance(approval, ToolReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
elif isinstance(approval, ApprovalReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
elif isinstance(approval, dict):
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
else:
raise TypeError(f"Unexpected approval type: {type(approval)}")
return serialized_tool_returns
def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalReturn, ToolReturn]]:
"""Convert a JSON list back into ApprovalReturn and ToolReturn objects."""
if not data:
return []
approvals = []
for item in data:
if "approve" in item:
approval_return = ApprovalReturn(**item)
approvals.append(approval_return)
elif "status" in item:
tool_return = ToolReturn(**item)
approvals.append(tool_return)
else:
raise TypeError(f"Unexpected approval type: {type(item)}")
return approvals
# ---------------------------- # ----------------------------
# MessageContent Serialization # MessageContent Serialization
# ---------------------------- # ----------------------------

View File

@@ -3,6 +3,7 @@ from sqlalchemy.types import BINARY, TypeDecorator
from letta.helpers.converters import ( from letta.helpers.converters import (
deserialize_agent_step_state, deserialize_agent_step_state,
deserialize_approvals,
deserialize_batch_request_result, deserialize_batch_request_result,
deserialize_create_batch_response, deserialize_create_batch_response,
deserialize_embedding_config, deserialize_embedding_config,
@@ -16,6 +17,7 @@ from letta.helpers.converters import (
deserialize_tool_rules, deserialize_tool_rules,
deserialize_vector, deserialize_vector,
serialize_agent_step_state, serialize_agent_step_state,
serialize_approvals,
serialize_batch_request_result, serialize_batch_request_result,
serialize_create_batch_response, serialize_create_batch_response,
serialize_embedding_config, serialize_embedding_config,
@@ -96,6 +98,19 @@ class ToolReturnColumn(TypeDecorator):
return deserialize_tool_returns(value) return deserialize_tool_returns(value)
class ApprovalsColumn(TypeDecorator):
"""Custom SQLAlchemy column type for storing the approval responses of a tool call request as JSON."""
impl = JSON
cache_ok = True
def process_bind_param(self, value, dialect):
return serialize_approvals(value)
def process_result_value(self, value, dialect):
return deserialize_approvals(value)
class MessageContentColumn(TypeDecorator): class MessageContentColumn(TypeDecorator):
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON.""" """Custom SQLAlchemy column type for storing the content parts of a message as JSON."""

View File

@@ -4,10 +4,11 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn from letta.orm.custom_columns import ApprovalsColumn, MessageContentColumn, ToolCallColumn, ToolReturnColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ApprovalReturn
from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent
from letta.schemas.message import Message as PydanticMessage, ToolReturn from letta.schemas.message import Message as PydanticMessage, ToolReturn
from letta.settings import DatabaseChoice, settings from letta.settings import DatabaseChoice, settings
@@ -63,6 +64,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
) )
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.") approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.") denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
approvals: Mapped[Optional[List[ApprovalReturn | ToolReturn]]] = mapped_column(
ApprovalsColumn, nullable=True, doc="Approval responses for tool call requests"
)
# Monotonically increasing sequence for efficient/correct listing # Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column( sequence_id: Mapped[int] = mapped_column(

View File

@@ -224,6 +224,7 @@ class Message(BaseMessage):
) )
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.") approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.") denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects # This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -341,23 +342,37 @@ class Message(BaseMessage):
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"})) approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
messages.append(approval_request_message) messages.append(approval_request_message)
else: else:
approval_response_message = ApprovalResponseMessage( if self.approvals:
id=self.id, first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0]
date=self.created_at, approval_response_message = ApprovalResponseMessage(
otid=self.otid, id=self.id,
approve=self.approve, date=self.created_at,
approval_request_id=self.approval_request_id, otid=self.otid,
reason=self.denial_reason, approvals=self.approvals,
approvals=[ run_id=self.run_id,
# TODO: temporary workaround to populate from legacy fields # TODO: temporary populate these fields for backwards compatibility
ApprovalReturn( approve=first_approval.approve,
tool_call_id=self.approval_request_id, approval_request_id=first_approval.tool_call_id,
approve=self.approve, reason=first_approval.reason,
reason=self.denial_reason, )
) else:
], approval_response_message = ApprovalResponseMessage(
run_id=self.run_id, id=self.id,
) date=self.created_at,
otid=self.otid,
approve=self.approve,
approval_request_id=self.approval_request_id,
reason=self.denial_reason,
approvals=[
# TODO: temporary workaround to populate from legacy fields
ApprovalReturn(
tool_call_id=self.approval_request_id,
approve=self.approve,
reason=self.denial_reason,
)
],
run_id=self.run_id,
)
messages.append(approval_response_message) messages.append(approval_response_message)
else: else:
raise ValueError(f"Unknown role: {self.role}") raise ValueError(f"Unknown role: {self.role}")