diff --git a/alembic/versions/066857381578_add_approvals_field_to_messages.py b/alembic/versions/066857381578_add_approvals_field_to_messages.py new file mode 100644 index 00000000..ad882852 --- /dev/null +++ b/alembic/versions/066857381578_add_approvals_field_to_messages.py @@ -0,0 +1,41 @@ +"""add approvals field to messages + +Revision ID: 066857381578 +Revises: c734cfc0d595 +Create Date: 2025-10-09 17:56:07.333221 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +import letta.orm +from alembic import op +from letta.settings import settings + +# revision identifiers, used by Alembic. +revision: str = "066857381578" +down_revision: Union[str, None] = "c734cfc0d595" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Skip this migration for SQLite + if not settings.letta_pg_uri_no_default: + return + + ### commands auto generated by Alembic - please adjust! ### + op.add_column("messages", sa.Column("approvals", letta.orm.custom_columns.ApprovalsColumn(), nullable=True)) + ### end Alembic commands ### + + +def downgrade() -> None: + # Skip this migration for SQLite + if not settings.letta_pg_uri_no_default: + return + + ### commands auto generated by Alembic - please adjust! ### + op.drop_column("messages", "approvals") + ### end Alembic commands ### diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 6ded1937..3deec412 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -8,6 +8,7 @@ from sqlalchemy import Dialect from letta.functions.mcp_client.types import StdioServerConfig from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ProviderType, ToolRuleType +from letta.schemas.letta_message import ApprovalReturn from letta.schemas.letta_message_content import ( ImageContent, ImageSourceType, @@ -222,6 +223,49 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]: return tool_returns +# -------------------------- +# Approvals Serialization +# -------------------------- + + +def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]: + """Convert a list of ToolReturn objects into JSON-serializable format.""" + if not approvals: + return [] + + serialized_approvals = [] + for approval in approvals: + if isinstance(approval, ToolReturn): + serialized_approvals.append(approval.model_dump(mode="json")) + elif isinstance(approval, ApprovalReturn): + serialized_approvals.append(approval.model_dump(mode="json")) + elif isinstance(approval, dict): + serialized_approvals.append(approval) # Already a dictionary, leave it as-is + else: + raise TypeError(f"Unexpected approval type: {type(approval)}") + + return serialized_tool_returns + + +def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalReturn, ToolReturn]]: + """Convert a JSON list back into ApprovalReturn and ToolReturn objects.""" + if not data: + return [] + + approvals = [] + for item in data: + if "approve" in item: + approval_return = ApprovalReturn(**item) + approvals.append(approval_return) + elif "status" in item: + tool_return = ToolReturn(**item) + approvals.append(tool_return) + else: + raise TypeError(f"Unexpected approval type: {type(item)}") + + return approvals + + # ---------------------------- # MessageContent Serialization # ---------------------------- diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index 686b35dc..56cbc7e5 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -3,6 +3,7 @@ from sqlalchemy.types import BINARY, TypeDecorator from letta.helpers.converters import ( deserialize_agent_step_state, + deserialize_approvals, deserialize_batch_request_result, deserialize_create_batch_response, deserialize_embedding_config, @@ -16,6 +17,7 @@ from letta.helpers.converters import ( deserialize_tool_rules, deserialize_vector, serialize_agent_step_state, + serialize_approvals, serialize_batch_request_result, serialize_create_batch_response, serialize_embedding_config, @@ -96,6 +98,19 @@ class ToolReturnColumn(TypeDecorator): return deserialize_tool_returns(value) +class ApprovalsColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing the approval responses of a tool call request as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_approvals(value) + + def process_result_value(self, value, dialect): + return deserialize_approvals(value) + + class MessageContentColumn(TypeDecorator): """Custom SQLAlchemy column type for storing the content parts of a message as JSON.""" diff --git a/letta/orm/message.py b/letta/orm/message.py index e7ddda0e..770ae8d7 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -4,10 +4,11 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text from sqlalchemy.orm import Mapped, Session, mapped_column, relationship -from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn +from letta.orm.custom_columns import ApprovalsColumn, MessageContentColumn, ToolCallColumn, ToolReturnColumn from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import ApprovalReturn from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent from letta.schemas.message import Message as PydanticMessage, ToolReturn from letta.settings import DatabaseChoice, settings @@ -63,6 +64,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): ) approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.") denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.") + approvals: Mapped[Optional[List[ApprovalReturn | ToolReturn]]] = mapped_column( + ApprovalsColumn, nullable=True, doc="Approval responses for tool call requests" + ) # Monotonically increasing sequence for efficient/correct listing sequence_id: Mapped[int] = mapped_column( diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 4f456bfa..9f2d7e57 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -224,6 +224,7 @@ class Message(BaseMessage): ) approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.") denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.") + approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") @@ -341,23 +342,37 @@ class Message(BaseMessage): approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"})) messages.append(approval_request_message) else: - approval_response_message = ApprovalResponseMessage( - id=self.id, - date=self.created_at, - otid=self.otid, - approve=self.approve, - approval_request_id=self.approval_request_id, - reason=self.denial_reason, - approvals=[ - # TODO: temporary workaround to populate from legacy fields - ApprovalReturn( - tool_call_id=self.approval_request_id, - approve=self.approve, - reason=self.denial_reason, - ) - ], - run_id=self.run_id, - ) + if self.approvals: + first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0] + approval_response_message = ApprovalResponseMessage( + id=self.id, + date=self.created_at, + otid=self.otid, + approvals=self.approvals, + run_id=self.run_id, + # TODO: temporary populate these fields for backwards compatibility + approve=first_approval.approve, + approval_request_id=first_approval.tool_call_id, + reason=first_approval.reason, + ) + else: + approval_response_message = ApprovalResponseMessage( + id=self.id, + date=self.created_at, + otid=self.otid, + approve=self.approve, + approval_request_id=self.approval_request_id, + reason=self.denial_reason, + approvals=[ + # TODO: temporary workaround to populate from legacy fields + ApprovalReturn( + tool_call_id=self.approval_request_id, + approve=self.approve, + reason=self.denial_reason, + ) + ], + run_id=self.run_id, + ) messages.append(approval_response_message) else: raise ValueError(f"Unknown role: {self.role}")