Files
letta-server/letta/schemas/letta_message.py
Sarah Wooders f61a4b8319 Revert "feat: add input data to pydantic validation error logging" (#5847)
Revert "feat: add input data to pydantic validation error logging (#5748)"

This reverts commit 0a61c0c2c1fa0e09867120af93f17ab6304a795f.
2025-11-13 15:36:14 -08:00

523 lines
21 KiB
Python

import json
from datetime import datetime, timezone
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
from letta.schemas.letta_message_content import (
LettaAssistantMessageContentUnion,
LettaUserMessageContentUnion,
get_letta_assistant_message_content_union_str_json_schema,
get_letta_user_message_content_union_str_json_schema,
)
# ---------------------------
# Letta API Messaging Schemas
# ---------------------------
class MessageReturnType(str, Enum):
    """Discriminator values used in the ``type`` field of ApprovalReturn and ToolReturn."""

    # Member order is part of the public enum (iteration order, schema enum order).
    approval = "approval"
    tool = "tool"
class MessageReturn(BaseModel):
    """Base model for return entries; subclasses narrow ``type`` to a single MessageReturnType."""

    type: MessageReturnType = Field(..., description="The message type to be created.")
class ApprovalReturn(MessageReturn):
    """An approval decision for the tool call identified by ``tool_call_id``."""

    # Fixed discriminator for this variant of LettaMessageReturnUnion.
    type: Literal[MessageReturnType.approval] = Field(
        default=MessageReturnType.approval,
        description="The message type to be created.",
    )
    tool_call_id: str = Field(..., description="The ID of the tool call that corresponds to this approval")
    approve: bool = Field(..., description="Whether the tool has been approved")
    reason: Optional[str] = Field(
        default=None,
        description="An optional explanation for the provided approval status",
    )
class ToolReturn(MessageReturn):
    """The return value and status of a single tool call."""

    # Fixed discriminator for this variant of LettaMessageReturnUnion.
    type: Literal[MessageReturnType.tool] = Field(
        default=MessageReturnType.tool,
        description="The message type to be created.",
    )
    tool_return: str
    status: Literal["success", "error"]
    tool_call_id: str
    # Optional captured output lines from the tool invocation.
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
# Union of return entries (approvals and tool results), dispatched by Pydantic on the
# ``type`` field, whose values come from MessageReturnType.
LettaMessageReturnUnion = Annotated[Union[ApprovalReturn, ToolReturn], Field(discriminator="type")]
class MessageType(str, Enum):
    """Values of the ``message_type`` discriminator carried by every LettaMessage subtype."""

    # Plain conversation messages.
    system_message = "system_message"
    user_message = "user_message"
    assistant_message = "assistant_message"
    # Agent reasoning, either visible or hidden.
    reasoning_message = "reasoning_message"
    hidden_reasoning_message = "hidden_reasoning_message"
    # Tool calling and its results.
    tool_call_message = "tool_call_message"
    tool_return_message = "tool_return_message"
    # Human-in-the-loop approval flow.
    approval_request_message = "approval_request_message"
    approval_response_message = "approval_response_message"
class LettaMessage(BaseModel):
    """
    Base class for simplified Letta message response type. This is intended to be used for developers
    who want the internal monologue, tool calls, and tool returns in a simplified format that does not
    include additional information other than the content and timestamp.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (MessageType): The type of the message
        otid (Optional[str]): The offline threading id associated with this message
        sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
        step_id (Optional[str]): The step id associated with the message
        is_err (Optional[bool]): Whether the message is an errored message or not. Used for debugging purposes only.
        seq_id (Optional[int]): The sequence id associated with the message
        run_id (Optional[str]): The run id associated with the message
    """

    id: str
    date: datetime
    name: str | None = None
    # Subclasses narrow this to a Literal so it can act as a union discriminator.
    message_type: MessageType = Field(..., description="The type of the message.")
    otid: str | None = None
    sender_id: str | None = None
    step_id: str | None = None
    is_err: bool | None = None
    seq_id: int | None = None
    run_id: str | None = None

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        """
        Serialize ``date`` as an ISO-8601 string truncated to whole seconds.

        A naive datetime (no tzinfo, or a tzinfo that reports no UTC offset) is
        tagged as UTC first, so the output always carries an offset. Microseconds
        are dropped because they are not populated consistently.
        TODO: figure out why we don't always get microseconds (get_utc_time() does)
        """
        lacks_offset = dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None
        if lacks_offset:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.isoformat(timespec="seconds")

    @field_serializer("is_err", mode="wrap")
    def serialize_is_err(self, value: bool | None, handler, _info):
        """
        Emit ``is_err`` only when it is True (for debugging purposes).

        For None or False the serializer returns None, so the field is written
        as null. Returning None does not by itself remove the key from the output.
        """
        if value is not True:
            return None
        return handler(value)
class SystemMessage(LettaMessage):
    """
    A message generated by the system. Never streamed back on a response, only used for cursor pagination.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        content (str): The message content sent by the system
    """

    message_type: Literal[MessageType.system_message] = Field(
        default=MessageType.system_message,
        description="The type of the message.",
    )
    content: str = Field(..., description="The message content sent by the system")
class UserMessage(LettaMessage):
    """
    A message sent by the user. Never streamed back on a response, only used for cursor pagination.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user
            (can be a string or an array of multi-modal content parts)
    """

    message_type: Literal[MessageType.user_message] = Field(
        default=MessageType.user_message,
        description="The type of the message.",
    )
    # The extra JSON schema comes from the content module so the published
    # schema for this string-or-parts field matches its other users.
    content: Union[str, List[LettaUserMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
        json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
    )
class ReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning
            content was generated natively by a reasoner model or derived via prompting
        reasoning (str): The internal reasoning of the agent
        signature (Optional[str]): The model-generated signature of the reasoning step
    """

    message_type: Literal[MessageType.reasoning_message] = Field(
        default=MessageType.reasoning_message,
        description="The type of the message.",
    )
    # Falls back to "non_reasoner_model" when the producer does not specify it.
    source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model"
    reasoning: str
    signature: Optional[str] = None
class HiddenReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning where reasoning content
    has been hidden from the response.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        state (Literal["redacted", "omitted"]): Whether the reasoning
            content was redacted by the provider or simply omitted by the API
        hidden_reasoning (Optional[str]): The internal reasoning of the agent
    """

    message_type: Literal[MessageType.hidden_reasoning_message] = Field(
        default=MessageType.hidden_reasoning_message,
        description="The type of the message.",
    )
    # Required: there is no default, so producers must say why the content is hidden.
    state: Literal["redacted", "omitted"]
    hidden_reasoning: Optional[str] = None
class ToolCall(BaseModel):
    """A complete tool call: tool name, arguments, and call id are all required."""

    name: str
    # Arguments are carried as a single string (presumably JSON-encoded — TODO confirm).
    arguments: str
    tool_call_id: str
class ToolCallDelta(BaseModel):
    """A partial tool call in which any of the fields may be missing."""

    name: Optional[str] = None
    arguments: Optional[str] = None
    tool_call_id: Optional[str] = None

    def model_dump(self, *args, **kwargs):
        """
        Dump the model, always leaving out keys whose value is None.

        This is a workaround to exclude None values from the JSON dump since the
        OpenAI style of returning chunks doesn't include keys with null values,
        so ``exclude_none`` is forced on regardless of what the caller passed.
        """
        forced_kwargs = {**kwargs, "exclude_none": True}
        return super().model_dump(*args, **forced_kwargs)

    def json(self, *args, **kwargs):
        """Return the None-stripped dump as a JSON string; extra arguments are passed to ``json.dumps``."""
        payload = self.model_dump(exclude_none=True)
        return json.dumps(payload, *args, **kwargs)
class ToolCallMessage(LettaMessage):
    """
    A message representing a request to call a tool (generated by the LLM to trigger tool execution).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        tool_call (Union[ToolCall, ToolCallDelta]): The tool call (deprecated)
        tool_calls (Optional[Union[List[ToolCall], ToolCallDelta]]): The tool calls, either a list
            of complete calls or a single partial delta
    """

    message_type: Literal[MessageType.tool_call_message] = Field(
        default=MessageType.tool_call_message, description="The type of the message."
    )
    tool_call: Union[ToolCall, ToolCallDelta] = Field(..., deprecated=True)
    tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = None

    def model_dump(self, *args, **kwargs):
        """
        Dump the message with None values stripped, including inside tool_call/tool_calls.

        ``exclude_none`` is forced on so that ToolCallDelta payloads never carry
        null keys, overriding whatever the caller passed.
        """
        kwargs["exclude_none"] = True
        data = super().model_dump(*args, **kwargs)
        # Defensive second pass over the nested tool call payload(s).
        if isinstance(data.get("tool_call"), dict):
            data["tool_call"] = {k: v for k, v in data["tool_call"].items() if v is not None}
        if isinstance(data.get("tool_calls"), dict):
            data["tool_calls"] = {k: v for k, v in data["tool_calls"].items() if v is not None}
        elif isinstance(data.get("tool_calls"), list):
            data["tool_calls"] = [
                {k: v for k, v in item.items() if v is not None} if isinstance(item, dict) else item for item in data["tool_calls"]
            ]
        return data

    class Config:
        json_encoders = {
            ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
            ToolCall: lambda v: v.model_dump(exclude_none=True),
        }

    @field_validator("tool_call", mode="before")
    @classmethod
    def validate_tool_call(cls, v):
        """
        Cast a raw dict into a ToolCall or ToolCallDelta before union validation.

        Without this extra validator, a dict whose 'name' or 'arguments' are None
        would be routed to ToolCall and fail instead of becoming a ToolCallDelta.

        - If 'name', 'arguments' and 'tool_call_id' are all present and not None,
          the dict becomes a complete ToolCall.
        - Otherwise, if at least one of those keys is present (even with a None
          value), it becomes a ToolCallDelta.
        - A dict with none of those keys is rejected with ValueError.

        Non-dict values are returned unchanged for Pydantic to validate.
        """
        if isinstance(v, dict):
            keys = ("name", "arguments", "tool_call_id")
            # Check values, not just key presence: {"name": None, ...} must not become a ToolCall.
            if all(v.get(key) is not None for key in keys):
                return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
            if any(key in v for key in keys):
                return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
            raise ValueError("tool_call must contain at least one of 'name', 'arguments', or 'tool_call_id'")
        return v
class ToolReturnMessage(LettaMessage):
    """
    A message representing the return value of a tool call (generated by Letta executing the requested tool).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        tool_return (str): The return value of the tool (deprecated, use tool_returns)
        status (Literal["success", "error"]): The status of the tool call (deprecated, use tool_returns)
        tool_call_id (str): A unique identifier for the tool call that generated this message (deprecated, use tool_returns)
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the tool invocation (deprecated, use tool_returns)
        stderr (Optional[List[str]]): Captured stderr from the tool invocation (deprecated, use tool_returns)
        tool_returns (Optional[List[ToolReturn]]): List of tool returns for multi-tool support
    """

    message_type: Literal[MessageType.tool_return_message] = Field(
        default=MessageType.tool_return_message,
        description="The type of the message.",
    )
    # Single-tool fields, kept for backward compatibility; still required.
    tool_return: str = Field(..., deprecated=True)
    status: Literal["success", "error"] = Field(..., deprecated=True)
    tool_call_id: str = Field(..., deprecated=True)
    stdout: Optional[List[str]] = Field(default=None, deprecated=True)
    stderr: Optional[List[str]] = Field(default=None, deprecated=True)
    # Multi-tool replacement for the fields above.
    tool_returns: Optional[List[ToolReturn]] = None
class ApprovalRequestMessage(LettaMessage):
    """
    A message representing a request for approval to call a tool (generated by the LLM to trigger tool execution).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        tool_call (Union[ToolCall, ToolCallDelta]): The tool call (deprecated)
        tool_calls (Optional[Union[List[ToolCall], ToolCallDelta]]): The tool calls pending approval
    """

    message_type: Literal[MessageType.approval_request_message] = Field(
        default=MessageType.approval_request_message,
        description="The type of the message.",
    )
    tool_call: Union[ToolCall, ToolCallDelta] = Field(
        ...,
        description="The tool call that has been requested by the llm to run",
        deprecated=True,
    )
    tool_calls: Optional[Union[List[ToolCall], ToolCallDelta]] = Field(
        default=None,
        description="The tool calls that have been requested by the llm to run, which are pending approval",
    )
class ApprovalResponseMessage(LettaMessage):
    """
    A message representing a response from the user indicating whether a tool has been approved to run.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        approvals (Optional[List[LettaMessageReturnUnion]]): The list of approval responses
        approve (Optional[bool]): Whether the tool has been approved (deprecated)
        approval_request_id (Optional[str]): The message ID of the approval request (deprecated)
        reason (Optional[str]): An optional explanation for the provided approval status (deprecated)
    """

    message_type: Literal[MessageType.approval_response_message] = Field(
        default=MessageType.approval_response_message,
        description="The type of the message.",
    )
    approvals: Optional[List[LettaMessageReturnUnion]] = Field(default=None, description="The list of approval responses")
    # Single-approval fields, superseded by ``approvals``.
    approve: Optional[bool] = Field(default=None, description="Whether the tool has been approved", deprecated=True)
    approval_request_id: Optional[str] = Field(
        default=None,
        description="The message ID of the approval request",
        deprecated=True,
    )
    reason: Optional[str] = Field(
        default=None,
        description="An optional explanation for the provided approval status",
        deprecated=True,
    )
class AssistantMessage(LettaMessage):
    """
    A message sent by the LLM in response to user input. Used in the LLM context.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent
            (can be a string or an array of content parts)
    """

    message_type: Literal[MessageType.assistant_message] = Field(
        default=MessageType.assistant_message,
        description="The type of the message.",
    )
    # Schema extras are shared with the other assistant-content fields in this module.
    content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the agent (can be a string or an array of content parts)",
        json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
    )
class LettaPing(LettaMessage):
    """
    A ping message used as a keepalive to prevent SSE streams from timing out during long running requests.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Overrides the MessageType-typed base field with a plain "ping" literal;
    # LettaPing is not part of LettaMessageUnion.
    message_type: Literal["ping"] = Field(
        default="ping",
        description="The type of the message. Ping messages are a keep-alive to prevent SSE streams from timing out during long running requests.",
    )
# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
# Union of the concrete message types, dispatched on the ``message_type`` field.
# LettaPing is not a member. create_letta_message_union_schema() hand-writes the
# OpenAPI equivalent with the same members in the same order, so keep the two in sync.
LettaMessageUnion = Annotated[
    Union[
        SystemMessage,
        UserMessage,
        ReasoningMessage,
        HiddenReasoningMessage,
        ToolCallMessage,
        ToolReturnMessage,
        AssistantMessage,
        ApprovalRequestMessage,
        ApprovalResponseMessage,
    ],
    Field(discriminator="message_type"),
]
def create_letta_message_union_schema():
    """
    Build the OpenAPI ``oneOf`` schema mirroring LettaMessageUnion.

    Each member is referenced by its component schema, and a discriminator on
    ``message_type`` maps every tag to that same reference. A new dict is
    built on every call, so callers may mutate the result.
    """
    # (message_type tag, component schema name), in union order.
    variants = (
        ("system_message", "SystemMessage"),
        ("user_message", "UserMessage"),
        ("reasoning_message", "ReasoningMessage"),
        ("hidden_reasoning_message", "HiddenReasoningMessage"),
        ("tool_call_message", "ToolCallMessage"),
        ("tool_return_message", "ToolReturnMessage"),
        ("assistant_message", "AssistantMessage"),
        ("approval_request_message", "ApprovalRequestMessage"),
        ("approval_response_message", "ApprovalResponseMessage"),
    )
    ref_by_tag = {tag: f"#/components/schemas/{schema_name}" for tag, schema_name in variants}
    return {
        "oneOf": [{"$ref": ref} for ref in ref_by_tag.values()],
        "discriminator": {
            "propertyName": "message_type",
            "mapping": ref_by_tag,
        },
    }
def create_letta_ping_schema():
    """
    Build the JSON schema describing a LettaPing keep-alive message.

    A new dict is built on every call.
    """
    # NOTE(review): the LettaPing model also inherits the required ``id`` and ``date``
    # fields from LettaMessage, which this hand-written schema does not list — confirm intended.
    message_type_property = {
        "type": "string",
        "const": "ping",
        "title": "Message Type",
        "description": "The type of the message.",
        "default": "ping",
    }
    return {
        "properties": {"message_type": message_type_property},
        "type": "object",
        "required": ["message_type"],
        "title": "LettaPing",
        "description": "Ping messages are a keep-alive to prevent SSE streams from timing out during long running requests.",
    }
# --------------------------
# Message Update API Schemas
# --------------------------
class UpdateSystemMessage(BaseModel):
    """Update payload for a system message; the new content is plain text."""

    message_type: Literal["system_message"] = "system_message"
    # The field is typed ``str`` only, so the description must not advertise
    # multi-modal content parts (that wording belonged to UpdateUserMessage).
    # It now matches SystemMessage.content's description.
    content: str = Field(..., description="The message content sent by the system")
class UpdateUserMessage(BaseModel):
    """Update payload for a user message; content may be a string or a list of content parts."""

    message_type: Literal["user_message"] = "user_message"
    # Same type, description and schema extras as UserMessage.content.
    content: Union[str, List[LettaUserMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
        json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
    )
class UpdateReasoningMessage(BaseModel):
    """Update payload replacing the reasoning text of a reasoning message."""

    # Field order (reasoning first) is preserved: it fixes schema property
    # order and model_dump key order.
    reasoning: str
    message_type: Literal["reasoning_message"] = "reasoning_message"
class UpdateAssistantMessage(BaseModel):
    """Update payload for an assistant message; content may be a string or a list of content parts."""

    message_type: Literal["assistant_message"] = "assistant_message"
    # Same type and schema extras as AssistantMessage.content.
    content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the assistant (can be a string or an array of content parts)",
        json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
    )
# Union of the message-update payloads, dispatched by Pydantic on ``message_type``.
LettaMessageUpdateUnion = Annotated[
    Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
    Field(discriminator="message_type"),
]
# --------------------------
# Deprecated Message Schemas
# --------------------------
class LegacyFunctionCallMessage(LettaMessage):
    """
    Deprecated message carrying a function call as a string.

    Args:
        function_call (str): The function call
    """

    # NOTE(review): unlike the other legacy messages, this class does not override
    # message_type, so it inherits the required MessageType-typed field from
    # LettaMessage — confirm that is intended.
    function_call: str
class LegacyFunctionReturn(LettaMessage):
    """
    A message representing the return value of a function call (generated by Letta executing the requested function).

    Args:
        function_return (str): The return value of the function
        status (Literal["success", "error"]): The status of the function call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        function_call_id (str): A unique identifier for the function call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the function invocation
        stderr (Optional[List[str]]): Captured stderr from the function invocation
    """

    # Legacy tag; not a MessageType member.
    message_type: Literal["function_return"] = "function_return"
    function_return: str
    status: Literal["success", "error"]
    function_call_id: str
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
class LegacyInternalMonologue(LettaMessage):
    """
    Representation of an agent's internal monologue (deprecated form of reasoning).

    Args:
        internal_monologue (str): The internal monologue of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Legacy tag; not a MessageType member.
    message_type: Literal["internal_monologue"] = "internal_monologue"
    internal_monologue: str
# Plain (non-discriminated) union of the legacy message shapes; AssistantMessage
# is shared with the current schema.
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]