diff --git a/fern/openapi.json b/fern/openapi.json index 1922eb1a..dac07f7d 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -22323,7 +22323,21 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/ApprovalReturn" + "oneOf": [ + { + "$ref": "#/components/schemas/ApprovalReturn" + }, + { + "$ref": "#/components/schemas/letta__schemas__letta_message__ToolReturn" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "approval": "#/components/schemas/ApprovalReturn", + "tool": "#/components/schemas/letta__schemas__letta_message__ToolReturn" + } + } }, "type": "array" }, @@ -22590,7 +22604,21 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/ApprovalReturn" + "oneOf": [ + { + "$ref": "#/components/schemas/ApprovalReturn" + }, + { + "$ref": "#/components/schemas/letta__schemas__letta_message__ToolReturn" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "approval": "#/components/schemas/ApprovalReturn", + "tool": "#/components/schemas/letta__schemas__letta_message__ToolReturn" + } + } }, "type": "array" }, @@ -22648,6 +22676,13 @@ }, "ApprovalReturn": { "properties": { + "type": { + "type": "string", + "const": "approval", + "title": "Type", + "description": "The message type to be created.", + "default": "approval" + }, "tool_call_id": { "type": "string", "title": "Tool Call Id", @@ -30055,14 +30090,21 @@ "anyOf": [ { "items": { - "anyOf": [ + "oneOf": [ { "$ref": "#/components/schemas/ApprovalReturn" }, { - "$ref": "#/components/schemas/letta__schemas__message__ToolReturn" + "$ref": "#/components/schemas/letta__schemas__letta_message__ToolReturn" } - ] + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "approval": "#/components/schemas/ApprovalReturn", + "tool": "#/components/schemas/letta__schemas__letta_message__ToolReturn" + } + } }, "type": "array" }, @@ -34380,71 +34422,6 @@ "required": ["name", "description", "parameters"], "title": "ToolJSONSchema" }, - "ToolReturn-Input": { - "properties": { - "tool_call_id": { - "anyOf": [ - {}, - { - "type": "null" - } - ], - "title": "Tool Call Id", - "description": "The ID for the tool call" - }, - "status": { - "type": "string", - "enum": ["success", "error"], - "title": "Status", - "description": "The status of the tool call" - }, - "stdout": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Stdout", - "description": "Captured stdout (e.g. prints, logs) from the tool invocation" - }, - "stderr": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Stderr", - "description": "Captured stderr from the tool invocation" - }, - "func_response": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Func Response", - "description": "The function response string" - } - }, - "type": "object", - "required": ["status"], - "title": "ToolReturn" - }, "ToolReturnContent": { "properties": { "type": { @@ -36775,7 +36752,7 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/ToolReturn-Input" + "$ref": "#/components/schemas/letta__schemas__message__ToolReturn" }, "type": "array" }, @@ -37023,6 +37000,13 @@ }, "letta__schemas__letta_message__ToolReturn": { "properties": { + "type": { + "type": "string", + "const": "tool", + "title": "Type", + "description": "The message type to be created.", + "default": "tool" + }, "tool_return": { "type": "string", "title": "Tool Return" diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 1065f105..c067eff9 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -8,7 +8,7 @@ from sqlalchemy import Dialect from letta.functions.mcp_client.types import StdioServerConfig from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ProviderType, ToolRuleType -from letta.schemas.letta_message import ApprovalReturn +from letta.schemas.letta_message import ApprovalReturn, MessageReturnType from letta.schemas.letta_message_content import ( ImageContent, ImageSourceType, @@ -235,9 +235,9 @@ def serialize_approvals(approvals: Optional[List[Union[ApprovalReturn, ToolRetur serialized_approvals = [] for approval in approvals: - if isinstance(approval, ToolReturn): + if isinstance(approval, ApprovalReturn): serialized_approvals.append(approval.model_dump(mode="json")) - elif isinstance(approval, ApprovalReturn): + elif isinstance(approval, ToolReturn): serialized_approvals.append(approval.model_dump(mode="json")) elif isinstance(approval, dict): serialized_approvals.append(approval) # Already a dictionary, leave it as-is @@ -254,14 +254,14 @@ def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalRetu approvals = [] for item in data: - if "approve" in item: + if "type" in item and item.get("type") == MessageReturnType.approval: approval_return = ApprovalReturn(**item) approvals.append(approval_return) elif "status" in item: tool_return = ToolReturn(**item) approvals.append(tool_return) else: - raise TypeError(f"Unexpected approval type: {type(item)}") + continue return approvals diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index e67ba46a..0d104230 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -17,6 +17,34 @@ from letta.schemas.letta_message_content import ( # --------------------------- +class MessageReturnType(str, Enum): + approval = "approval" + tool = "tool" + + +class MessageReturn(BaseModel): + type: MessageReturnType = Field(..., description="The message type to be created.") + + +class ApprovalReturn(MessageReturn): + type: Literal[MessageReturnType.approval] = Field(default=MessageReturnType.approval, description="The message type to be created.") + tool_call_id: str = Field(..., description="The ID of the tool call that corresponds to this approval") + approve: bool = Field(..., description="Whether the tool has been approved") + reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status") + + +class ToolReturn(MessageReturn): + type: Literal[MessageReturnType.tool] = Field(default=MessageReturnType.tool, description="The message type to be created.") + tool_return: str + status: Literal["success", "error"] + tool_call_id: str + stdout: Optional[List[str]] = None + stderr: Optional[List[str]] = None + + +LettaMessageReturnUnion = Annotated[Union[ApprovalReturn, ToolReturn], Field(discriminator="type")] + + class MessageType(str, Enum): system_message = "system_message" user_message = "user_message" @@ -233,14 +261,6 @@ class ToolCallMessage(LettaMessage): return v -class ToolReturn(BaseModel): - tool_return: str - status: Literal["success", "error"] - tool_call_id: str - stdout: Optional[List[str]] = None - stderr: Optional[List[str]] = None - - class ToolReturnMessage(LettaMessage): """ A message representing the return value of a tool call (generated by Letta executing the requested tool). @@ -285,12 +305,6 @@ class ApprovalRequestMessage(LettaMessage): tool_call: Union[ToolCall, ToolCallDelta] = Field(..., description="The tool call that has been requested by the llm to run") -class ApprovalReturn(BaseModel): - tool_call_id: str = Field(..., description="The ID of the tool call that corresponds to this approval") - approve: bool = Field(..., description="Whether the tool has been approved") - reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status") - - class ApprovalResponseMessage(LettaMessage): """ A message representing a response form the user indicating whether a tool has been approved to run. @@ -307,7 +321,7 @@ class ApprovalResponseMessage(LettaMessage): message_type: Literal[MessageType.approval_response_message] = Field( default=MessageType.approval_response_message, description="The type of the message." ) - approvals: Optional[List[ApprovalReturn]] = Field(default=None, description="The list of approval responses") + approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approval responses") approve: Optional[bool] = Field(None, description="Whether the tool has been approved", deprecated=True) approval_request_id: Optional[str] = Field(None, description="The message ID of the approval request", deprecated=True) reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status", deprecated=True) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 891e4bf8..b876212d 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -27,6 +27,7 @@ from letta.schemas.letta_message import ( AssistantMessage, HiddenReasoningMessage, LettaMessage, + LettaMessageReturnUnion, MessageType, ReasoningMessage, SystemMessage, @@ -117,7 +118,7 @@ class ApprovalCreate(MessageCreateBase): """Input to approve or deny a tool call request""" type: Literal[MessageCreateType.approval] = Field(default=MessageCreateType.approval, description="The message type to be created.") - approvals: Optional[List[ApprovalReturn]] = Field(default=None, description="The list of approval responses") + approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approval responses") approve: Optional[bool] = Field(None, description="Whether the tool has been approved", deprecated=True) approval_request_id: Optional[str] = Field(None, description="The message ID of the approval request", deprecated=True) reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status", deprecated=True) @@ -224,7 +225,7 @@ class Message(BaseMessage): ) approve: Optional[bool] = Field(default=None, description="Whether tool call is approved.") denial_reason: Optional[str] = Field(default=None, description="The reason the tool call request was denied.") - approvals: Optional[List[ApprovalReturn | ToolReturn]] = Field(default=None, description="The list of approvals for this message.") + approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approvals for this message.") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")