feat: add approvals persistence to message orm (#5309)
* feat: add approvals persistence to message orm * fix imports in alembic migration * missing import
This commit is contained in:
@@ -0,0 +1,41 @@
|
|||||||
|
"""add approvals field to messages
|
||||||
|
|
||||||
|
Revision ID: 066857381578
|
||||||
|
Revises: c734cfc0d595
|
||||||
|
Create Date: 2025-10-09 17:56:07.333221
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
import letta.orm
|
||||||
|
from alembic import op
|
||||||
|
from letta.settings import settings
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "066857381578"
|
||||||
|
down_revision: Union[str, None] = "c734cfc0d595"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Skip this migration for SQLite
|
||||||
|
if not settings.letta_pg_uri_no_default:
|
||||||
|
return
|
||||||
|
|
||||||
|
### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column("messages", sa.Column("approvals", letta.orm.custom_columns.ApprovalsColumn(), nullable=True))
|
||||||
|
### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Skip this migration for SQLite
|
||||||
|
if not settings.letta_pg_uri_no_default:
|
||||||
|
return
|
||||||
|
|
||||||
|
### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("messages", "approvals")
|
||||||
|
### end Alembic commands ###
|
||||||
@@ -8,6 +8,7 @@ from sqlalchemy import Dialect
|
|||||||
from letta.functions.mcp_client.types import StdioServerConfig
|
from letta.functions.mcp_client.types import StdioServerConfig
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.enums import ProviderType, ToolRuleType
|
from letta.schemas.enums import ProviderType, ToolRuleType
|
||||||
|
from letta.schemas.letta_message import ApprovalReturn
|
||||||
from letta.schemas.letta_message_content import (
|
from letta.schemas.letta_message_content import (
|
||||||
ImageContent,
|
ImageContent,
|
||||||
ImageSourceType,
|
ImageSourceType,
|
||||||
@@ -222,6 +223,49 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
|
|||||||
return tool_returns
|
return tool_returns
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# Approvals Serialization
|
||||||
|
# --------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolReturn, dict]]]) -> List[Dict]:
|
||||||
|
"""Convert a list of ToolReturn objects into JSON-serializable format."""
|
||||||
|
if not approvals:
|
||||||
|
return []
|
||||||
|
|
||||||
|
serialized_approvals = []
|
||||||
|
for approval in approvals:
|
||||||
|
if isinstance(approval, ToolReturn):
|
||||||
|
serialized_approvals.append(approval.model_dump(mode="json"))
|
||||||
|
elif isinstance(approval, ApprovalReturn):
|
||||||
|
serialized_approvals.append(approval.model_dump(mode="json"))
|
||||||
|
elif isinstance(approval, dict):
|
||||||
|
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unexpected approval type: {type(approval)}")
|
||||||
|
|
||||||
|
return serialized_tool_returns
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalReturn, ToolReturn]]:
|
||||||
|
"""Convert a JSON list back into ApprovalReturn and ToolReturn objects."""
|
||||||
|
if not data:
|
||||||
|
return []
|
||||||
|
|
||||||
|
approvals = []
|
||||||
|
for item in data:
|
||||||
|
if "approve" in item:
|
||||||
|
approval_return = ApprovalReturn(**item)
|
||||||
|
approvals.append(approval_return)
|
||||||
|
elif "status" in item:
|
||||||
|
tool_return = ToolReturn(**item)
|
||||||
|
approvals.append(tool_return)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unexpected approval type: {type(item)}")
|
||||||
|
|
||||||
|
return approvals
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# MessageContent Serialization
|
# MessageContent Serialization
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from sqlalchemy.types import BINARY, TypeDecorator
|
|||||||
|
|
||||||
from letta.helpers.converters import (
|
from letta.helpers.converters import (
|
||||||
deserialize_agent_step_state,
|
deserialize_agent_step_state,
|
||||||
|
deserialize_approvals,
|
||||||
deserialize_batch_request_result,
|
deserialize_batch_request_result,
|
||||||
deserialize_create_batch_response,
|
deserialize_create_batch_response,
|
||||||
deserialize_embedding_config,
|
deserialize_embedding_config,
|
||||||
@@ -16,6 +17,7 @@ from letta.helpers.converters import (
|
|||||||
deserialize_tool_rules,
|
deserialize_tool_rules,
|
||||||
deserialize_vector,
|
deserialize_vector,
|
||||||
serialize_agent_step_state,
|
serialize_agent_step_state,
|
||||||
|
serialize_approvals,
|
||||||
serialize_batch_request_result,
|
serialize_batch_request_result,
|
||||||
serialize_create_batch_response,
|
serialize_create_batch_response,
|
||||||
serialize_embedding_config,
|
serialize_embedding_config,
|
||||||
@@ -96,6 +98,19 @@ class ToolReturnColumn(TypeDecorator):
|
|||||||
return deserialize_tool_returns(value)
|
return deserialize_tool_returns(value)
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalsColumn(TypeDecorator):
|
||||||
|
"""Custom SQLAlchemy column type for storing the approval responses of a tool call request as JSON."""
|
||||||
|
|
||||||
|
impl = JSON
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
return serialize_approvals(value)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
return deserialize_approvals(value)
|
||||||
|
|
||||||
|
|
||||||
class MessageContentColumn(TypeDecorator):
|
class MessageContentColumn(TypeDecorator):
|
||||||
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""
|
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
|
|||||||
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
|
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
from letta.orm.custom_columns import ApprovalsColumn, MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
|
from letta.schemas.letta_message import ApprovalReturn
|
||||||
from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent
|
from letta.schemas.letta_message_content import MessageContent, TextContent, TextContent as PydanticTextContent
|
||||||
from letta.schemas.message import Message as PydanticMessage, ToolReturn
|
from letta.schemas.message import Message as PydanticMessage, ToolReturn
|
||||||
from letta.settings import DatabaseChoice, settings
|
from letta.settings import DatabaseChoice, settings
|
||||||
@@ -63,6 +64,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|||||||
)
|
)
|
||||||
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
|
approve: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether tool call is approved.")
|
||||||
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
|
denial_reason: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The reason the tool call request was denied.")
|
||||||
|
approvals: Mapped[Optional[List[ApprovalReturn | ToolReturn]]] = mapped_column(
|
||||||
|
ApprovalsColumn, nullable=True, doc="Approval responses for tool call requests"
|
||||||
|
)
|
||||||
|
|
||||||
# Monotonically increasing sequence for efficient/correct listing
|
# Monotonically increasing sequence for efficient/correct listing
|
||||||
sequence_id: Mapped[int] = mapped_column(
|
sequence_id: Mapped[int] = mapped_column(
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ class Message(BaseMessage):
|
|||||||
)
|
)
|
||||||
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
|
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
|
||||||
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
|
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
|
||||||
|
approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.")
|
||||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||||
|
|
||||||
@@ -341,23 +342,37 @@ class Message(BaseMessage):
|
|||||||
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
|
approval_request_message = ApprovalRequestMessage(**tool_calls[0].model_dump(exclude={"message_type"}))
|
||||||
messages.append(approval_request_message)
|
messages.append(approval_request_message)
|
||||||
else:
|
else:
|
||||||
approval_response_message = ApprovalResponseMessage(
|
if self.approvals:
|
||||||
id=self.id,
|
first_approval = [a for a in self.approvals if isinstance(a, ApprovalReturn)][0]
|
||||||
date=self.created_at,
|
approval_response_message = ApprovalResponseMessage(
|
||||||
otid=self.otid,
|
id=self.id,
|
||||||
approve=self.approve,
|
date=self.created_at,
|
||||||
approval_request_id=self.approval_request_id,
|
otid=self.otid,
|
||||||
reason=self.denial_reason,
|
approvals=self.approvals,
|
||||||
approvals=[
|
run_id=self.run_id,
|
||||||
# TODO: temporary workaround to populate from legacy fields
|
# TODO: temporary populate these fields for backwards compatibility
|
||||||
ApprovalReturn(
|
approve=first_approval.approve,
|
||||||
tool_call_id=self.approval_request_id,
|
approval_request_id=first_approval.tool_call_id,
|
||||||
approve=self.approve,
|
reason=first_approval.reason,
|
||||||
reason=self.denial_reason,
|
)
|
||||||
)
|
else:
|
||||||
],
|
approval_response_message = ApprovalResponseMessage(
|
||||||
run_id=self.run_id,
|
id=self.id,
|
||||||
)
|
date=self.created_at,
|
||||||
|
otid=self.otid,
|
||||||
|
approve=self.approve,
|
||||||
|
approval_request_id=self.approval_request_id,
|
||||||
|
reason=self.denial_reason,
|
||||||
|
approvals=[
|
||||||
|
# TODO: temporary workaround to populate from legacy fields
|
||||||
|
ApprovalReturn(
|
||||||
|
tool_call_id=self.approval_request_id,
|
||||||
|
approve=self.approve,
|
||||||
|
reason=self.denial_reason,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
run_id=self.run_id,
|
||||||
|
)
|
||||||
messages.append(approval_response_message)
|
messages.append(approval_response_message)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown role: {self.role}")
|
raise ValueError(f"Unknown role: {self.role}")
|
||||||
|
|||||||
Reference in New Issue
Block a user