feat: add discriminator type to message return objects (#5318)

This commit is contained in:
cthomas
2025-10-10 11:00:16 -07:00
committed by Caren Thomas
parent bbc3de5845
commit 6fd6232992
4 changed files with 92 additions and 93 deletions

View File

@@ -22323,7 +22323,21 @@
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ApprovalReturn"
"oneOf": [
{
"$ref": "#/components/schemas/ApprovalReturn"
},
{
"$ref": "#/components/schemas/letta__schemas__letta_message__ToolReturn"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"approval": "#/components/schemas/ApprovalReturn",
"tool": "#/components/schemas/letta__schemas__letta_message__ToolReturn"
}
}
},
"type": "array"
},
@@ -22590,7 +22604,21 @@
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ApprovalReturn"
"oneOf": [
{
"$ref": "#/components/schemas/ApprovalReturn"
},
{
"$ref": "#/components/schemas/letta__schemas__letta_message__ToolReturn"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"approval": "#/components/schemas/ApprovalReturn",
"tool": "#/components/schemas/letta__schemas__letta_message__ToolReturn"
}
}
},
"type": "array"
},
@@ -22648,6 +22676,13 @@
},
"ApprovalReturn": {
"properties": {
"type": {
"type": "string",
"const": "approval",
"title": "Type",
"description": "The message type to be created.",
"default": "approval"
},
"tool_call_id": {
"type": "string",
"title": "Tool Call Id",
@@ -30055,14 +30090,21 @@
"anyOf": [
{
"items": {
"anyOf": [
"oneOf": [
{
"$ref": "#/components/schemas/ApprovalReturn"
},
{
"$ref": "#/components/schemas/letta__schemas__message__ToolReturn"
"$ref": "#/components/schemas/letta__schemas__letta_message__ToolReturn"
}
]
],
"discriminator": {
"propertyName": "type",
"mapping": {
"approval": "#/components/schemas/ApprovalReturn",
"tool": "#/components/schemas/letta__schemas__letta_message__ToolReturn"
}
}
},
"type": "array"
},
@@ -34380,71 +34422,6 @@
"required": ["name", "description", "parameters"],
"title": "ToolJSONSchema"
},
"ToolReturn-Input": {
"properties": {
"tool_call_id": {
"anyOf": [
{},
{
"type": "null"
}
],
"title": "Tool Call Id",
"description": "The ID for the tool call"
},
"status": {
"type": "string",
"enum": ["success", "error"],
"title": "Status",
"description": "The status of the tool call"
},
"stdout": {
"anyOf": [
{
"items": {
"type": "string"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Stdout",
"description": "Captured stdout (e.g. prints, logs) from the tool invocation"
},
"stderr": {
"anyOf": [
{
"items": {
"type": "string"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Stderr",
"description": "Captured stderr from the tool invocation"
},
"func_response": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Func Response",
"description": "The function response string"
}
},
"type": "object",
"required": ["status"],
"title": "ToolReturn"
},
"ToolReturnContent": {
"properties": {
"type": {
@@ -36775,7 +36752,7 @@
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ToolReturn-Input"
"$ref": "#/components/schemas/letta__schemas__message__ToolReturn"
},
"type": "array"
},
@@ -37023,6 +37000,13 @@
},
"letta__schemas__letta_message__ToolReturn": {
"properties": {
"type": {
"type": "string",
"const": "tool",
"title": "Type",
"description": "The message type to be created.",
"default": "tool"
},
"tool_return": {
"type": "string",
"title": "Tool Return"

View File

@@ -8,7 +8,7 @@ from sqlalchemy import Dialect
from letta.functions.mcp_client.types import StdioServerConfig
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType, ToolRuleType
from letta.schemas.letta_message import ApprovalReturn
from letta.schemas.letta_message import ApprovalReturn, MessageReturnType
from letta.schemas.letta_message_content import (
ImageContent,
ImageSourceType,
@@ -235,9 +235,9 @@ def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolRetur
serialized_approvals = []
for approval in approvals:
if isinstance(approval, ToolReturn):
if isinstance(approval, ApprovalReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
elif isinstance(approval, ApprovalReturn):
elif isinstance(approval, ToolReturn):
serialized_approvals.append(approval.model_dump(mode="json"))
elif isinstance(approval, dict):
serialized_approvals.append(approval) # Already a dictionary, leave it as-is
@@ -254,14 +254,14 @@ def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalRetu
approvals = []
for item in data:
if "approve" in item:
if "type" in item and item.get("type") == MessageReturnType.approval:
approval_return = ApprovalReturn(**item)
approvals.append(approval_return)
elif "status" in item:
tool_return = ToolReturn(**item)
approvals.append(tool_return)
else:
raise TypeError(f"Unexpected approval type: {type(item)}")
continue
return approvals

View File

@@ -17,6 +17,34 @@ from letta.schemas.letta_message_content import (
# ---------------------------
class MessageReturnType(str, Enum):
approval = "approval"
tool = "tool"
class MessageReturn(BaseModel):
type: MessageReturnType = Field(..., description="The message type to be created.")
class ApprovalReturn(MessageReturn):
type: Literal[MessageReturnType.approval] = Field(default=MessageReturnType.approval, description="The message type to be created.")
tool_call_id: str = Field(..., description="The ID of the tool call that corresponds to this approval")
approve: bool = Field(..., description="Whether the tool has been approved")
reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status")
class ToolReturn(MessageReturn):
type: Literal[MessageReturnType.tool] = Field(default=MessageReturnType.tool, description="The message type to be created.")
tool_return: str
status: Literal["success", "error"]
tool_call_id: str
stdout: Optional[List[str]] = None
stderr: Optional[List[str]] = None
LettaMessageReturnUnion = Annotated[Union[ApprovalReturn, ToolReturn], Field(discriminator="type")]
class MessageType(str, Enum):
system_message = "system_message"
user_message = "user_message"
@@ -233,14 +261,6 @@ class ToolCallMessage(LettaMessage):
return v
class ToolReturn(BaseModel):
tool_return: str
status: Literal["success", "error"]
tool_call_id: str
stdout: Optional[List[str]] = None
stderr: Optional[List[str]] = None
class ToolReturnMessage(LettaMessage):
"""
A message representing the return value of a tool call (generated by Letta executing the requested tool).
@@ -285,12 +305,6 @@ class ApprovalRequestMessage(LettaMessage):
tool_call: Union[ToolCall, ToolCallDelta] = Field(..., description="The tool call that has been requested by the llm to run")
class ApprovalReturn(BaseModel):
tool_call_id: str = Field(..., description="The ID of the tool call that corresponds to this approval")
approve: bool = Field(..., description="Whether the tool has been approved")
reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status")
class ApprovalResponseMessage(LettaMessage):
"""
A message representing a response form the user indicating whether a tool has been approved to run.
@@ -307,7 +321,7 @@ class ApprovalResponseMessage(LettaMessage):
message_type: Literal[MessageType.approval_response_message] = Field(
default=MessageType.approval_response_message, description="The type of the message."
)
approvals: Optional[List[ApprovalReturn]] = Field(default=None, description="The list of approval responses")
approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approval responses")
approve: Optional[bool] = Field(None, description="Whether the tool has been approved", deprecated=True)
approval_request_id: Optional[str] = Field(None, description="The message ID of the approval request", deprecated=True)
reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status", deprecated=True)

View File

@@ -27,6 +27,7 @@ from letta.schemas.letta_message import (
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
LettaMessageReturnUnion,
MessageType,
ReasoningMessage,
SystemMessage,
@@ -117,7 +118,7 @@ class ApprovalCreate(MessageCreateBase):
"""Input to approve or deny a tool call request"""
type: Literal[MessageCreateType.approval] = Field(default=MessageCreateType.approval, description="The message type to be created.")
approvals: Optional[List[ApprovalReturn]] = Field(default=None, description="The list of approval responses")
approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approval responses")
approve: Optional[bool] = Field(None, description="Whether the tool has been approved", deprecated=True)
approval_request_id: Optional[str] = Field(None, description="The message ID of the approval request", deprecated=True)
reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status", deprecated=True)
@@ -224,7 +225,7 @@ class Message(BaseMessage):
)
approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.")
denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.")
approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.")
approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approvals for this message.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")