feat: add approvals persistence to message orm (#5309)
* feat: add approvals persistence to message orm * fix imports in alembic migration * missing import
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""add approvals field to messages
|
||||
|
||||
Revision ID: 066857381578
|
||||
Revises: c734cfc0d595
|
||||
Create Date: 2025-10-09 17:56:07.333221
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import letta.orm
|
||||
from alembic import op
|
||||
from letta.settings import settings
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "066857381578"
|
||||
down_revision: Union[str, None] = "c734cfc0d595"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Skip this migration for SQLite
|
||||
if not settings.letta_pg_uri_no_default:
|
||||
return
|
||||
|
||||
### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("messages", sa.Column("approvals", letta.orm.custom_columns.ApprovalsColumn(), nullable=True))
|
||||
### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Skip this migration for SQLite
|
||||
if not settings.letta_pg_uri_no_default:
|
||||
return
|
||||
|
||||
### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("messages", "approvals")
|
||||
### end Alembic commands ###
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy import Dialect
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderType, ToolRuleType
|
||||
from letta.schemas.letta_message import ApprovalReturn
|
||||
from letta.schemas.letta_message_content import (
|
||||
ImageContent,
|
||||
ImageSourceType,
|
||||
@@ -222,6 +223,49 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
|
||||
return tool_returns
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Approvals Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of ToolReturn objects into JSON-serializable format."""
|
||||
if not approvals:
|
||||
return []
|
||||
|
||||
serialized_approvals = []
|
||||
for approval in approvals:
|
||||
if isinstance(approval, ToolReturn):
|
||||
serialized_approvals.append(approval.model_dump(mode="json"))
|
||||
elif isinstance(approval, ApprovalReturn):
|
||||
serialized_approvals.append(approval.model_dump(mode="json"))
|
||||
elif isinstance(approval, dict):
|
||||
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
|
||||
else:
|
||||
raise TypeError(f"Unexpected approval type: {type(approval)}")
|
||||
|
||||
return serialized_tool_returns
|
||||
|
||||
|
||||
def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalReturn, ToolReturn]]:
|
||||
"""Convert a JSON list back into ApprovalReturn and ToolReturn objects."""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
approvals = []
|
||||
for item in data:
|
||||
if "approve" in item:
|
||||
approval_return = ApprovalReturn(**item)
|
||||
approvals.append(approval_return)
|
||||
elif "status" in item:
|
||||
tool_return = ToolReturn(**item)
|
||||
approvals.append(tool_return)
|
||||
else:
|
||||
raise TypeError(f"Unexpected approval type: {type(item)}")
|
||||
|
||||
return approvals
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# MessageContent Serialization
|
||||
# ----------------------------
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.types import BINARY, TypeDecorator
|
||||
|
||||
from letta.helpers.converters import (
|
||||
deserialize_agent_step_state,
|
||||
deserialize_approvals,
|
||||
deserialize_batch_request_result,
|
||||
deserialize_create_batch_response,
|
||||
deserialize_embedding_config,
|
||||
@@ -16,6 +17,7 @@ from letta.helpers.converters import (
|
||||
deserialize_tool_rules,
|
||||
deserialize_vector,
|
||||
serialize_agent_step_state,
|
||||
serialize_approvals,
|
||||
serialize_batch_request_result,
|
||||
serialize_create_batch_response,
|
||||
serialize_embedding_config,
|
||||
@@ -96,6 +98,19 @@ class ToolReturnColumn(TypeDecorator):
|
||||
return deserialize_tool_returns(value)
|
||||
|
||||
|
||||
class ApprovalsColumn(TypeDecorator):
|
||||
"""Custom SQLAlchemy column type for storing the approval responses of a tool call request as JSON."""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return serialize_approvals(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_approvals(value)
|
||||
|
||||
|
||||
class MessageContentColumn(TypeDecorator):
|
||||
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
|
||||
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.custom_columns import ApprovalsColumn, MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import ApprovalReturn
|
||||
from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage, ToolReturn
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
@@ -63,6 +64,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
)
|
||||
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
|
||||
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
|
||||
approvals: Mapped[Optional[List[ApprovalReturn | ToolReturn]]] = mapped_column(
|
||||
ApprovalsColumn, nullable=True, doc="Approval responses for tool call requests"
|
||||
)
|
||||
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
|
||||
@@ -224,6 +224,7 @@ class Message(BaseMessage):
|
||||
)
|
||||
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
|
||||
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
|
||||
approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.")
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@@ -341,23 +342,37 @@ class Message(BaseMessage):
|
||||
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
|
||||
messages.append(approval_request_message)
|
||||
else:
|
||||
approval_response_message = ApprovalResponseMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
otid=self.otid,
|
||||
approve=self.approve,
|
||||
approval_request_id=self.approval_request_id,
|
||||
reason=self.denial_reason,
|
||||
approvals=[
|
||||
# TODO: temporary workaround to populate from legacy fields
|
||||
ApprovalReturn(
|
||||
tool_call_id=self.approval_request_id,
|
||||
approve=self.approve,
|
||||
reason=self.denial_reason,
|
||||
)
|
||||
],
|
||||
run_id=self.run_id,
|
||||
)
|
||||
if self.approvals:
|
||||
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0]
|
||||
approval_response_message = ApprovalResponseMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
otid=self.otid,
|
||||
approvals=self.approvals,
|
||||
run_id=self.run_id,
|
||||
# TODO: temporary populate these fields for backwards compatibility
|
||||
approve=first_approval.approve,
|
||||
approval_request_id=first_approval.tool_call_id,
|
||||
reason=first_approval.reason,
|
||||
)
|
||||
else:
|
||||
approval_response_message = ApprovalResponseMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
otid=self.otid,
|
||||
approve=self.approve,
|
||||
approval_request_id=self.approval_request_id,
|
||||
reason=self.denial_reason,
|
||||
approvals=[
|
||||
# TODO: temporary workaround to populate from legacy fields
|
||||
ApprovalReturn(
|
||||
tool_call_id=self.approval_request_id,
|
||||
approve=self.approve,
|
||||
reason=self.denial_reason,
|
||||
)
|
||||
],
|
||||
run_id=self.run_id,
|
||||
)
|
||||
messages.append(approval_response_message)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {self.role}")
|
||||
|
||||
Reference in New Issue
Block a user