220 lines
7.9 KiB
Python
220 lines
7.9 KiB
Python
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Annotated, List, Literal, Optional, Union
|
|
|
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
|
|
# Letta API style responses (intended to be easier to use vs getting true Message types)
|
|
|
|
|
|
class LettaMessage(BaseModel):
    """
    Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.

    Attributes:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
    # see `message_type` attribute

    id: str
    date: datetime

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        """Render `date` as an ISO-8601 string with second precision.

        A datetime with no tzinfo (or a tzinfo that reports no UTC offset) is
        labelled as UTC before formatting; an aware datetime keeps its own offset.
        """
        # Remove microseconds since it seems like we're inconsistent with getting them
        # TODO figure out why we don't always get microseconds (get_utc_time() does)
        is_naive = dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None
        stamped = dt.replace(tzinfo=timezone.utc) if is_naive else dt
        return stamped.isoformat(timespec="seconds")
|
|
|
|
|
|
class SystemMessage(LettaMessage):
    """
    A message generated by the system. Never streamed back on a response, only used for cursor pagination.

    Attributes:
        message (str): The message sent by the system
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Fixed tag for this message kind; `LettaMessageUnion` discriminates on this field.
    message_type: Literal["system_message"] = "system_message"
    # Text content of the system message.
    message: str
|
|
|
|
|
|
class UserMessage(LettaMessage):
    """
    A message sent by the user. Never streamed back on a response, only used for cursor pagination.

    Attributes:
        message (str): The message sent by the user
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Fixed tag for this message kind; `LettaMessageUnion` discriminates on this field.
    message_type: Literal["user_message"] = "user_message"
    # Text content of the user message.
    message: str
|
|
|
|
|
|
class ReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning.

    Attributes:
        reasoning (str): The internal reasoning of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Fixed tag for this message kind; `LettaMessageUnion` discriminates on this field.
    message_type: Literal["reasoning_message"] = "reasoning_message"
    # Text of the agent's reasoning.
    reasoning: str
|
|
|
|
|
|
class ToolCall(BaseModel):
    """
    A complete tool call: all three fields are required, non-optional strings.

    Attributes:
        name (str): The name of the tool being called
        arguments (str): The arguments for the call (presumably a JSON-encoded string; confirm against producers)
        tool_call_id (str): The identifier of this tool call
    """

    name: str
    arguments: str
    tool_call_id: str
|
|
|
|
|
|
class ToolCallDelta(BaseModel):
    """
    A partial tool call in which any field may be None.

    Attributes:
        name (Optional[str]): The name of the tool, if present
        arguments (Optional[str]): The arguments, if present
        tool_call_id (Optional[str]): The identifier of the tool call, if present
    """

    name: Optional[str]
    arguments: Optional[str]
    tool_call_id: Optional[str]

    # NOTE: this is a workaround to exclude None values from the JSON dump,
    # since the OpenAI style of returning chunks doesn't include keys with null values
    def model_dump(self, *args, **kwargs):
        """Dump the model with `exclude_none` forced to True, overriding any caller-supplied value."""
        forced = {**kwargs, "exclude_none": True}
        return super().model_dump(*args, **forced)

    def json(self, *args, **kwargs):
        """Return a JSON string of the None-free dump; all arguments are forwarded to `json.dumps`."""
        payload = self.model_dump(exclude_none=True)
        return json.dumps(payload, *args, **kwargs)
|
|
|
|
|
|
class ToolCallMessage(LettaMessage):
    """
    A message representing a request to call a tool (generated by the LLM to trigger tool execution).

    Attributes:
        tool_call (Union[ToolCall, ToolCallDelta]): The tool call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    message_type: Literal["tool_call_message"] = "tool_call_message"
    tool_call: Union[ToolCall, ToolCallDelta]

    # NOTE: this is required for the ToolCallDelta exclude_none to work correctly
    def model_dump(self, *args, **kwargs):
        """Dump the message with `exclude_none` forced to True.

        None-valued entries are also stripped from the nested `tool_call` dict.
        The `tool_call` key may be absent when the caller narrows the output
        with `include`/`exclude`, so it is looked up defensively.
        """
        kwargs["exclude_none"] = True
        data = super().model_dump(*args, **kwargs)
        tool_call = data.get("tool_call")
        if isinstance(tool_call, dict):
            data["tool_call"] = {k: v for k, v in tool_call.items() if v is not None}
        return data

    # NOTE(review): class-based `Config` and `json_encoders` are deprecated in Pydantic v2;
    # kept unchanged here to avoid altering serialization behavior.
    class Config:
        json_encoders = {
            ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
            ToolCall: lambda v: v.model_dump(exclude_none=True),
        }

    # NOTE: this is required to cast dicts into ToolCallMessage objects
    # Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
    # (instead of properly casting to ToolCallDelta instead of ToolCall)
    @field_validator("tool_call", mode="before")
    @classmethod
    def validate_tool_call(cls, v):
        """Coerce a raw dict into `ToolCall` or `ToolCallDelta` before field validation.

        A dict whose 'name', 'arguments' and 'tool_call_id' are all present and
        non-None becomes a `ToolCall`. A dict carrying at least one of those keys
        (missing or None values allowed) becomes a `ToolCallDelta`. Non-dict
        values are returned untouched.

        Raises:
            ValueError: if the dict contains none of the three keys.
        """
        if not isinstance(v, dict):
            return v

        keys = ("name", "arguments", "tool_call_id")
        if not any(key in v for key in keys):
            raise ValueError("tool_call must contain at least one of 'name', 'arguments', or 'tool_call_id'")

        # Route on values, not just key presence: ToolCall's fields are plain `str`,
        # so an explicit None must fall through to the all-optional delta type.
        if all(v.get(key) is not None for key in keys):
            return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
        return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
|
|
|
|
|
|
class ToolReturnMessage(LettaMessage):
    """
    A message representing the return value of a tool call (generated by Letta executing the requested tool).

    Attributes:
        tool_return (str): The return value of the tool
        status (Literal["success", "error"]): The status of the tool call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        tool_call_id (str): A unique identifier for the tool call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the tool invocation
        stderr (Optional[List[str]]): Captured stderr from the tool invocation
    """

    # Fixed tag for this message kind; `LettaMessageUnion` discriminates on this field.
    message_type: Literal["tool_return_message"] = "tool_return_message"
    tool_return: str
    status: Literal["success", "error"]
    tool_call_id: str
    # Both capture fields default to None when not supplied.
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
|
|
|
|
|
|
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
|
|
|
|
|
|
class AssistantMessage(LettaMessage):
    """
    A message carrying assistant text, tagged "assistant_message" (a type from the legacy Letta API).

    Attributes:
        assistant_message (str): The text of the assistant message
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Fixed tag for this message kind; this class is a member of both
    # `LegacyLettaMessage` and `LettaMessageUnion`.
    message_type: Literal["assistant_message"] = "assistant_message"
    assistant_message: str
|
|
|
|
|
|
class LegacyFunctionCallMessage(LettaMessage):
    """
    Legacy representation of a function call, carried as a single formatted string.

    Attributes:
        function_call (str): The function call as a formatted string
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # NOTE(review): unlike the sibling message classes, this one declares no
    # `message_type` field. Confirm whether that is intentional.
    function_call: str
|
|
|
|
|
|
class LegacyFunctionReturn(LettaMessage):
    """
    A message representing the return value of a function call (generated by Letta executing the requested function).

    Attributes:
        function_return (str): The return value of the function
        status (Literal["success", "error"]): The status of the function call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        function_call_id (str): A unique identifier for the function call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the function invocation
        stderr (Optional[List[str]]): Captured stderr from the function invocation
    """

    # Fixed tag for this legacy message kind.
    message_type: Literal["function_return"] = "function_return"
    function_return: str
    status: Literal["success", "error"]
    function_call_id: str
    # Both capture fields default to None when not supplied.
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
|
|
|
|
|
|
class LegacyInternalMonologue(LettaMessage):
    """
    Representation of an agent's internal monologue.

    Attributes:
        internal_monologue (str): The internal monologue of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Fixed tag for this legacy message kind.
    message_type: Literal["internal_monologue"] = "internal_monologue"
    internal_monologue: str
|
|
|
|
|
|
# Plain union of the legacy message shapes. No discriminator is declared here
# (LegacyFunctionCallMessage has no `message_type` field to discriminate on).
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
|
|
|
|
|
|
# Union of the current message types, declared as a Pydantic discriminated union
# keyed on `message_type`; each member pins that field to its own Literal value.
LettaMessageUnion = Annotated[
    Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
    Field(discriminator="message_type"),
]
|