Files
letta-server/letta/schemas/letta_message.py

287 lines
10 KiB
Python

import json
from datetime import datetime, timezone
from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
from letta.schemas.enums import MessageContentType
# Letta API style responses (intended to be easier to use vs getting true Message types)
class LettaMessage(BaseModel):
    """
    Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.

    Attributes:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
    # see `message_type` attribute
    id: str
    date: datetime

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        """Render ``date`` as an ISO-8601 string truncated to whole seconds.

        A datetime without a usable UTC offset is treated as UTC before formatting.
        """
        # datetime.utcoffset() yields None both when tzinfo is absent and when the
        # tzinfo cannot report an offset; either way the value is interpreted as UTC.
        aware = dt if dt.utcoffset() is not None else dt.replace(tzinfo=timezone.utc)
        # Microseconds are dropped because they are not populated consistently upstream.
        # TODO figure out why we don't always get microseconds (get_utc_time() does)
        return aware.isoformat(timespec="seconds")
class MessageContent(BaseModel):
    """
    Base model for a single part of a multi-part message content array.

    Subclasses narrow `type` to one `MessageContentType` member so that parts can be
    selected through the `type` discriminator used by `MessageContentUnion`.
    """

    type: MessageContentType = Field(..., description="The type of the message.")
class TextContent(MessageContent):
    """A plain-text message content part (`type` is always `MessageContentType.text`)."""

    type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
    text: str = Field(..., description="The text content of the message.")
# Union of all message content-part models, discriminated on each part's `type` field.
# TextContent is currently the only member; new content-part models belong in this Union.
MessageContentUnion = Annotated[
    Union[TextContent],
    Field(discriminator="type"),
]
class SystemMessage(LettaMessage):
    """
    A message generated by the system. Never streamed back on a response, only used for cursor pagination.

    Attributes:
        content (Union[str, List[MessageContentUnion]]): The content of the system message (can be a string or an array of content parts)
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["system_message"] = "system_message"
    content: Union[str, List[MessageContentUnion]]
class UserMessage(LettaMessage):
    """
    A message sent by the user. Never streamed back on a response, only used for cursor pagination.

    Attributes:
        content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["user_message"] = "user_message"
    content: Union[str, List[MessageContentUnion]]
class ReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning.

    Attributes:
        reasoning (str): The internal reasoning of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["reasoning_message"] = "reasoning_message"
    reasoning: str
class ToolCall(BaseModel):
    """A complete tool call: every field is a required string (contrast with `ToolCallDelta`)."""

    name: str
    # Argument payload carried as a single string -- presumably JSON-encoded; confirm with producers.
    arguments: str
    tool_call_id: str
class ToolCallDelta(BaseModel):
    """
    A partial tool call chunk, as produced when tool calls are streamed.

    Every field may be missing. Serialization always omits `None` values, so each
    field defaults to `None`; this lets a dumped delta be validated again and keeps
    the schema from advertising keys as required that are never emitted.
    """

    # Defaults are required here: in Pydantic v2, `Optional[X]` without a default is a
    # *required* (nullable) field, which conflicts with the exclude_none dumps below.
    name: Optional[str] = None
    arguments: Optional[str] = None
    tool_call_id: Optional[str] = None

    # NOTE: this is a workaround to exclude None values from the JSON dump,
    # since the OpenAI style of returning chunks doesn't include keys with null values
    def model_dump(self, *args, **kwargs):
        """Dump the delta, always forcing ``exclude_none=True``."""
        kwargs["exclude_none"] = True
        return super().model_dump(*args, **kwargs)

    def json(self, *args, **kwargs):
        """Return a JSON string of the non-None fields; extra arguments go to ``json.dumps``."""
        return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)
class ToolCallMessage(LettaMessage):
    """
    A message representing a request to call a tool (generated by the LLM to trigger tool execution).

    Attributes:
        tool_call (Union[ToolCall, ToolCallDelta]): The tool call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["tool_call_message"] = "tool_call_message"
    tool_call: Union[ToolCall, ToolCallDelta]

    # NOTE: this is required for the ToolCallDelta exclude_none to work correctly
    def model_dump(self, *args, **kwargs):
        """Dump the message, always omitting None-valued fields.

        ``tool_call`` may be absent from the dumped dict when the caller passes
        ``include``/``exclude``, so it is only post-processed when present.
        """
        kwargs["exclude_none"] = True
        data = super().model_dump(*args, **kwargs)
        tool_call = data.get("tool_call")
        if isinstance(tool_call, dict):
            data["tool_call"] = {k: v for k, v in tool_call.items() if v is not None}
        return data

    class Config:
        json_encoders = {
            ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
            ToolCall: lambda v: v.model_dump(exclude_none=True),
        }

    # NOTE: this is required to cast dicts into ToolCallMessage objects
    # Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
    # (instead of properly casting to ToolCallDelta instead of ToolCall)
    @field_validator("tool_call", mode="before")
    @classmethod
    def validate_tool_call(cls, v):
        """Coerce a raw dict into a ``ToolCall`` or ``ToolCallDelta``.

        A dict becomes a ``ToolCall`` only when all three values are non-None. If at
        least one of the keys is present otherwise, it becomes a ``ToolCallDelta``.
        Non-dict values are passed through for normal Pydantic validation.

        Raises:
            ValueError: if the dict contains none of 'name', 'arguments', 'tool_call_id'.
        """
        if isinstance(v, dict):
            keys = ("name", "arguments", "tool_call_id")
            # Check values, not just key presence: a dict carrying null values (e.g. a
            # serialized partial chunk) must become a delta instead of failing ToolCall validation.
            if all(v.get(key) is not None for key in keys):
                return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
            if any(key in v for key in keys):
                return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
            raise ValueError("tool_call must contain at least one of 'name', 'arguments', or 'tool_call_id'")
        return v
class ToolReturnMessage(LettaMessage):
    """
    A message representing the return value of a tool call (generated by Letta executing the requested tool).

    Attributes:
        tool_return (str): The return value of the tool
        status (Literal["success", "error"]): The status of the tool call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        tool_call_id (str): A unique identifier for the tool call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the tool invocation
        stderr (Optional[List[str]]): Captured stderr from the tool invocation
    """

    message_type: Literal["tool_return_message"] = "tool_return_message"
    tool_return: str
    status: Literal["success", "error"]
    tool_call_id: str
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
# Legacy Letta API message types. The legacy API additionally had the "assistant_message" type
# (AssistantMessage, which is also a member of LettaMessageUnion), and a function call was sent
# as a pre-formatted string (LegacyFunctionCallMessage) rather than a structured ToolCall.
class AssistantMessage(LettaMessage):
    """
    A message produced by the assistant.

    Attributes:
        content (Union[str, List[MessageContentUnion]]): The message content (can be a string or an array of content parts)
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["assistant_message"] = "assistant_message"
    content: Union[str, List[MessageContentUnion]]
class LegacyFunctionCallMessage(LettaMessage):
    """
    Legacy representation of a function call, carried as a single pre-formatted string.

    Attributes:
        function_call (str): The formatted function call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # NOTE(review): unlike the other legacy models this class declares no `message_type`
    # field -- confirm whether that omission is intentional.
    function_call: str
class LegacyFunctionReturn(LettaMessage):
    """
    A message representing the return value of a function call (generated by Letta executing the requested function).

    Attributes:
        function_return (str): The return value of the function
        status (Literal["success", "error"]): The status of the function call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        function_call_id (str): A unique identifier for the function call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the function invocation
        stderr (Optional[List[str]]): Captured stderr from the function invocation
    """

    message_type: Literal["function_return"] = "function_return"
    function_return: str
    status: Literal["success", "error"]
    function_call_id: str
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
class LegacyInternalMonologue(LettaMessage):
    """
    Representation of an agent's internal monologue (legacy counterpart of the reasoning message).

    Attributes:
        internal_monologue (str): The internal monologue of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["internal_monologue"] = "internal_monologue"
    internal_monologue: str
# Any legacy-format message. This is a plain Union with no discriminator (LegacyFunctionCallMessage
# has no `message_type` field), so Pydantic picks the member by validating against each type.
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
# Any current-format Letta message, discriminated on the `message_type` literal each member declares.
# Keep in sync with create_letta_message_union_schema(), which hand-writes the matching OpenAPI schema.
LettaMessageUnion = Annotated[
    Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
    Field(discriminator="message_type"),
]
class UpdateSystemMessage(BaseModel):
    """Fields accepted when updating a system message (new content as a string or content parts)."""

    content: Union[str, List[MessageContentUnion]]
    message_type: Literal["system_message"] = "system_message"
class UpdateUserMessage(BaseModel):
    """Fields accepted when updating a user message (new content as a string or content parts)."""

    content: Union[str, List[MessageContentUnion]]
    message_type: Literal["user_message"] = "user_message"
class UpdateReasoningMessage(BaseModel):
    """Fields accepted when updating a reasoning message."""

    # NOTE(review): accepts content parts, but ReasoningMessage.reasoning is declared as plain `str`;
    # confirm how a list value is stored/rendered.
    reasoning: Union[str, List[MessageContentUnion]]
    message_type: Literal["reasoning_message"] = "reasoning_message"
class UpdateAssistantMessage(BaseModel):
    """Fields accepted when updating an assistant message (new content as a string or content parts)."""

    content: Union[str, List[MessageContentUnion]]
    message_type: Literal["assistant_message"] = "assistant_message"
# Any message-update payload, discriminated on `message_type`. Only system, user, reasoning and
# assistant messages have update models; tool call/return messages have none here.
LettaMessageUpdateUnion = Annotated[
    Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
    Field(discriminator="message_type"),
]
def create_letta_message_union_schema():
    """
    Build the OpenAPI schema describing ``LettaMessageUnion``.

    Returns:
        dict: A newly built ``oneOf`` schema with one ``$ref`` per message component. It
        also has a ``discriminator`` on ``message_type`` that maps each tag to the same
        component ``$ref``.
    """
    # Tag -> component name, in the order the variants appear in LettaMessageUnion.
    tagged_components = (
        ("system_message", "SystemMessage"),
        ("user_message", "UserMessage"),
        ("reasoning_message", "ReasoningMessage"),
        ("tool_call_message", "ToolCallMessage"),
        ("tool_return_message", "ToolReturnMessage"),
        ("assistant_message", "AssistantMessage"),
    )
    ref_by_tag = {tag: f"#/components/schemas/{component}" for tag, component in tagged_components}

    discriminator = {"propertyName": "message_type", "mapping": ref_by_tag}
    variants = [{"$ref": ref} for ref in ref_by_tag.values()]
    return {"oneOf": variants, "discriminator": discriminator}