From e29f333cbe9c302fb789ca4e695f758fcba12490 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 13 Mar 2025 12:04:03 -0700 Subject: [PATCH] chore: message schema api improvements (#1267) --- letta/agent.py | 7 +- letta/agents/ephemeral_agent.py | 3 +- letta/dynamic_multi_agent.py | 2 +- letta/functions/function_sets/extras.py | 3 +- letta/memory.py | 3 +- letta/orm/message.py | 2 +- letta/round_robin_multi_agent.py | 2 +- letta/schemas/enums.py | 4 - letta/schemas/letta_message.py | 269 ++++++++++-------- letta/schemas/letta_message_content.py | 56 ++++ letta/schemas/message.py | 21 +- letta/server/rest_api/app.py | 2 + letta/server/rest_api/routers/v1/agents.py | 2 +- letta/server/rest_api/utils.py | 2 +- letta/server/server.py | 3 +- .../services/helpers/agent_manager_helper.py | 3 +- letta/supervisor_multi_agent.py | 2 +- tests/integration_test_summarizer.py | 3 +- tests/test_static_buffer_summarize.py | 2 +- 19 files changed, 242 insertions(+), 149 deletions(-) create mode 100644 letta/schemas/letta_message_content.py diff --git a/letta/agent.py b/letta/agent.py index daa9a6c8..11414bce 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -39,7 +39,8 @@ from letta.orm.enums import ToolType from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageContentType, MessageRole +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message_content import TextContent from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, ToolReturn from letta.schemas.openai.chat_completion_response import ChatCompletionResponse @@ -165,7 +166,7 @@ class Agent(BaseAgent): in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) for i in range(len(in_context_messages) - 1, -1, -1): msg = in_context_messages[i] - if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and msg.content[0].type == MessageContentType.text: + if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): text_content = msg.content[0].text try: response_json = json.loads(text_content) @@ -1210,7 +1211,7 @@ class Agent(BaseAgent): and in_context_messages[1].role == MessageRole.user and in_context_messages[1].content and len(in_context_messages[1].content) == 1 - and in_context_messages[1].content[0].type == MessageContentType.text + and isinstance(in_context_messages[1].content[0], TextContent) # TODO remove hardcoding and "The following is a summary of the previous " in in_context_messages[1].content[0].text ): diff --git a/letta/agents/ephemeral_agent.py b/letta/agents/ephemeral_agent.py index e12d78b1..91e89343 100644 --- a/letta/agents/ephemeral_agent.py +++ b/letta/agents/ephemeral_agent.py @@ -5,7 +5,8 @@ import openai from letta.agents.base_agent import BaseAgent from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import TextContent, UserMessage +from letta.schemas.letta_message import UserMessage +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.user import User diff --git a/letta/dynamic_multi_agent.py b/letta/dynamic_multi_agent.py index 4b979ef8..93599324 100644 --- a/letta/dynamic_multi_agent.py +++ b/letta/dynamic_multi_agent.py @@ -4,7 +4,7 @@ from letta.agent import Agent, AgentState from letta.interface import AgentInterface from letta.orm import User from letta.schemas.block import Block -from letta.schemas.letta_message import TextContent +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.usage import LettaUsageStatistics diff --git a/letta/functions/function_sets/extras.py b/letta/functions/function_sets/extras.py index 8169b593..4c91af76 100644 --- a/letta/functions/function_sets/extras.py +++ b/letta/functions/function_sets/extras.py @@ -7,7 +7,8 @@ import requests from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE from letta.helpers.json_helpers import json_dumps, json_loads from letta.llm_api.llm_api_tools import create -from letta.schemas.message import Message, TextContent +from letta.schemas.letta_message_content import TextContent +from letta.schemas.message import Message def message_chatgpt(self, message: str): diff --git a/letta/memory.py b/letta/memory.py index c260eda1..7f68b007 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -5,8 +5,9 @@ from letta.llm_api.llm_api_tools import create from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole +from letta.schemas.letta_message_content import TextContent from letta.schemas.memory import Memory -from letta.schemas.message import Message, TextContent +from letta.schemas.message import Message from letta.settings import summarizer_settings from letta.utils import count_tokens, printd diff --git a/letta/orm/message.py b/letta/orm/message.py index 145642aa..64660d41 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -7,8 +7,8 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.custom_columns import ToolCallColumn, ToolReturnColumn from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.letta_message_content import TextContent as PydanticTextContent from letta.schemas.message import Message as PydanticMessage -from letta.schemas.message import TextContent as PydanticTextContent from letta.schemas.message import ToolReturn diff --git a/letta/round_robin_multi_agent.py b/letta/round_robin_multi_agent.py index bca50f40..1796b882 100644 --- a/letta/round_robin_multi_agent.py +++ b/letta/round_robin_multi_agent.py @@ -3,7 +3,7 @@ from typing import List, Optional from letta.agent import Agent, AgentState from letta.interface import AgentInterface from letta.orm import User -from letta.schemas.letta_message import TextContent +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.usage import LettaUsageStatistics diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 0b396d5d..1852aa5d 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -9,10 +9,6 @@ class MessageRole(str, Enum): system = "system" -class MessageContentType(str, Enum): - text = "text" - - class OptionState(str, Enum): """Useful for kwargs that are bool + default option""" diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index 305420e2..704e1298 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -4,87 +4,83 @@ from typing import Annotated, List, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator -from letta.schemas.enums import MessageContentType +from letta.schemas.letta_message_content import LettaMessageContentUnion, get_letta_message_content_union_str_json_schema -# Letta API style responses (intended to be easier to use vs getting true Message types) +# --------------------------- +# Letta API Messaging Schemas +# --------------------------- class LettaMessage(BaseModel): """ - Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp. + Base class for simplified Letta message response type. This is intended to be used for developers + who want the internal monologue, tool calls, and tool returns in a simplified format that does not + include additional information other than the content and timestamp. - Attributes: + Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format """ - # NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions - # see `message_type` attribute - id: str date: datetime @field_serializer("date") def serialize_datetime(self, dt: datetime, _info): + """ + Remove microseconds since it seems like we're inconsistent with getting them + TODO: figure out why we don't always get microseconds (get_utc_time() does) + """ if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: dt = dt.replace(tzinfo=timezone.utc) - # Remove microseconds since it seems like we're inconsistent with getting them - # TODO figure out why we don't always get microseconds (get_utc_time() does) return dt.isoformat(timespec="seconds") -class MessageContent(BaseModel): - type: MessageContentType = Field(..., description="The type of the message.") - - -class TextContent(MessageContent): - type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.") - text: str = Field(..., description="The text content of the message.") - - -MessageContentUnion = Annotated[ - Union[TextContent], - Field(discriminator="type"), -] - - class SystemMessage(LettaMessage): """ A message generated by the system. Never streamed back on a response, only used for cursor pagination. - Attributes: - content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts) + Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the system (can be a string or an array of content parts) """ message_type: Literal["system_message"] = "system_message" - content: Union[str, List[MessageContentUnion]] + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The message content sent by the system (can be a string or an array of content parts)", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) class UserMessage(LettaMessage): """ A message sent by the user. Never streamed back on a response, only used for cursor pagination. - Attributes: - content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts) + Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts) """ message_type: Literal["user_message"] = "user_message" - content: Union[str, List[MessageContentUnion]] + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The message content sent by the user (can be a string or an array of content parts)", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) class ReasoningMessage(LettaMessage): """ Representation of an agent's internal reasoning. - Attributes: - reasoning (str): The internal reasoning of the agent + Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + reasoning (str): The internal reasoning of the agent """ message_type: Literal["reasoning_message"] = "reasoning_message" @@ -92,21 +88,21 @@ class ReasoningMessage(LettaMessage): class ToolCall(BaseModel): - name: str arguments: str tool_call_id: str class ToolCallDelta(BaseModel): - name: Optional[str] arguments: Optional[str] tool_call_id: Optional[str] - # NOTE: this is a workaround to exclude None values from the JSON dump, - # since the OpenAI style of returning chunks doesn't include keys with null values def model_dump(self, *args, **kwargs): + """ + This is a workaround to exclude None values from the JSON dump since the + OpenAI style of returning chunks doesn't include keys with null values. + """ kwargs["exclude_none"] = True return super().model_dump(*args, **kwargs) @@ -118,17 +114,19 @@ class ToolCallMessage(LettaMessage): """ A message representing a request to call a tool (generated by the LLM to trigger tool execution). - Attributes: - tool_call (Union[ToolCall, ToolCallDelta]): The tool call + Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + tool_call (Union[ToolCall, ToolCallDelta]): The tool call """ message_type: Literal["tool_call_message"] = "tool_call_message" tool_call: Union[ToolCall, ToolCallDelta] - # NOTE: this is required for the ToolCallDelta exclude_none to work correctly def model_dump(self, *args, **kwargs): + """ + Handling for the ToolCallDelta exclude_none to work correctly + """ kwargs["exclude_none"] = True data = super().model_dump(*args, **kwargs) if isinstance(data["tool_call"], dict): @@ -141,12 +139,14 @@ class ToolCallMessage(LettaMessage): ToolCall: lambda v: v.model_dump(exclude_none=True), } - # NOTE: this is required to cast dicts into ToolCallMessage objects - # Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None - # (instead of properly casting to ToolCallDelta instead of ToolCall) @field_validator("tool_call", mode="before") @classmethod def validate_tool_call(cls, v): + """ + Casts dicts into ToolCallMessage objects. Without this extra validator, Pydantic will throw + an error if 'name' or 'arguments' are None instead of properly casting to ToolCallDelta + instead of ToolCall. + """ if isinstance(v, dict): if "name" in v and "arguments" in v and "tool_call_id" in v: return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"]) @@ -161,11 +161,11 @@ class ToolReturnMessage(LettaMessage): """ A message representing the return value of a tool call (generated by Letta executing the requested tool). - Attributes: - tool_return (str): The return value of the tool - status (Literal["success", "error"]): The status of the tool call + Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + tool_return (str): The return value of the tool + status (Literal["success", "error"]): The status of the tool call tool_call_id (str): A unique identifier for the tool call that generated this message stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation stderr (Optional[List(str)]): Captured stderr from the tool invocation @@ -179,89 +179,31 @@ class ToolReturnMessage(LettaMessage): stderr: Optional[List[str]] = None -# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string - - class AssistantMessage(LettaMessage): + """ + A message sent by the LLM in response to user input. Used in the LLM context. + + Args: + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts) + """ + message_type: Literal["assistant_message"] = "assistant_message" - content: Union[str, List[MessageContentUnion]] - - -class LegacyFunctionCallMessage(LettaMessage): - function_call: str - - -class LegacyFunctionReturn(LettaMessage): - """ - A message representing the return value of a function call (generated by Letta executing the requested function). - - Attributes: - function_return (str): The return value of the function - status (Literal["success", "error"]): The status of the function call - id (str): The ID of the message - date (datetime): The date the message was created in ISO format - function_call_id (str): A unique identifier for the function call that generated this message - stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation - stderr (Optional[List(str)]): Captured stderr from the function invocation - """ - - message_type: Literal["function_return"] = "function_return" - function_return: str - status: Literal["success", "error"] - function_call_id: str - stdout: Optional[List[str]] = None - stderr: Optional[List[str]] = None - - -class LegacyInternalMonologue(LettaMessage): - """ - Representation of an agent's internal monologue. - - Attributes: - internal_monologue (str): The internal monologue of the agent - id (str): The ID of the message - date (datetime): The date the message was created in ISO format - """ - - message_type: Literal["internal_monologue"] = "internal_monologue" - internal_monologue: str - - -LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn] + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The message content sent by the agent (can be a string or an array of content parts)", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) +# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions LettaMessageUnion = Annotated[ Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage], Field(discriminator="message_type"), ] -class UpdateSystemMessage(BaseModel): - content: Union[str, List[MessageContentUnion]] - message_type: Literal["system_message"] = "system_message" - - -class UpdateUserMessage(BaseModel): - content: Union[str, List[MessageContentUnion]] - message_type: Literal["user_message"] = "user_message" - - -class UpdateReasoningMessage(BaseModel): - reasoning: Union[str, List[MessageContentUnion]] - message_type: Literal["reasoning_message"] = "reasoning_message" - - -class UpdateAssistantMessage(BaseModel): - content: Union[str, List[MessageContentUnion]] - message_type: Literal["assistant_message"] = "assistant_message" - - -LettaMessageUpdateUnion = Annotated[ - Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage], - Field(discriminator="message_type"), -] - - def create_letta_message_union_schema(): return { "oneOf": [ @@ -284,3 +226,94 @@ def create_letta_message_union_schema(): }, }, } + + +# -------------------------- +# Message Update API Schemas +# -------------------------- + + +class UpdateSystemMessage(BaseModel): + message_type: Literal["system_message"] = "system_message" + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The message content sent by the system (can be a string or an array of content parts)", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) + + +class UpdateUserMessage(BaseModel): + message_type: Literal["user_message"] = "user_message" + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The message content sent by the user (can be a string or an array of content parts)", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) + + +class UpdateReasoningMessage(BaseModel): + reasoning: str + message_type: Literal["reasoning_message"] = "reasoning_message" + + +class UpdateAssistantMessage(BaseModel): + message_type: Literal["assistant_message"] = "assistant_message" + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The message content sent by the assistant (can be a string or an array of content parts)", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) + + +LettaMessageUpdateUnion = Annotated[ + Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage], + Field(discriminator="message_type"), +] + + +# -------------------------- +# Deprecated Message Schemas +# -------------------------- + + +class LegacyFunctionCallMessage(LettaMessage): + function_call: str + + +class LegacyFunctionReturn(LettaMessage): + """ + A message representing the return value of a function call (generated by Letta executing the requested function). + + Args: + function_return (str): The return value of the function + status (Literal["success", "error"]): The status of the function call + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + function_call_id (str): A unique identifier for the function call that generated this message + stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation + stderr (Optional[List(str)]): Captured stderr from the function invocation + """ + + message_type: Literal["function_return"] = "function_return" + function_return: str + status: Literal["success", "error"] + function_call_id: str + stdout: Optional[List[str]] = None + stderr: Optional[List[str]] = None + + +class LegacyInternalMonologue(LettaMessage): + """ + Representation of an agent's internal monologue. + + Args: + internal_monologue (str): The internal monologue of the agent + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + """ + + message_type: Literal["internal_monologue"] = "internal_monologue" + internal_monologue: str + + +LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn] diff --git a/letta/schemas/letta_message_content.py b/letta/schemas/letta_message_content.py new file mode 100644 index 00000000..2c27f492 --- /dev/null +++ b/letta/schemas/letta_message_content.py @@ -0,0 +1,56 @@ +from enum import Enum +from typing import Annotated, Literal, Union + +from pydantic import BaseModel, Field + + +class MessageContentType(str, Enum): + text = "text" + + +class MessageContent(BaseModel): + type: MessageContentType = Field(..., description="The type of the message.") + + +# ------------------------------- +# Multi-Modal Content Types +# ------------------------------- + + +class TextContent(MessageContent): + type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.") + text: str = Field(..., description="The text content of the message.") + + +LettaMessageContentUnion = Annotated[ + Union[TextContent], + Field(discriminator="type"), +] + + +def create_letta_message_content_union_schema(): + return { + "oneOf": [ + {"$ref": "#/components/schemas/TextContent"}, + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/TextContent", + }, + }, + } + + +def get_letta_message_content_union_str_json_schema(): + return { + "anyOf": [ + { + "type": "array", + "items": { + "$ref": "#/components/schemas/LettaMessageContentUnion", + }, + }, + {"type": "string"}, + ], + } diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 18f715b8..8190f60d 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -15,20 +15,19 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TO from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime from letta.helpers.json_helpers import json_dumps from letta.local_llm.constants import INNER_THOUGHTS_KWARG -from letta.schemas.enums import MessageContentType, MessageRole +from letta.schemas.enums import MessageRole from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( AssistantMessage, LettaMessage, - MessageContentUnion, ReasoningMessage, SystemMessage, - TextContent, ToolCall, ToolCallMessage, ToolReturnMessage, UserMessage, ) +from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent from letta.system import unpack_message @@ -66,7 +65,7 @@ class MessageCreate(BaseModel): MessageRole.user, MessageRole.system, ] = Field(..., description="The role of the participant.") - content: Union[str, List[MessageContentUnion]] = Field(..., description="The content of the message.") + content: Union[str, List[LettaMessageContentUnion]] = Field(..., description="The content of the message.") name: Optional[str] = Field(None, description="The name of the participant.") @@ -74,7 +73,7 @@ class MessageUpdate(BaseModel): """Request to update a message""" role: Optional[MessageRole] = Field(None, description="The role of the participant.") - content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.") + content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(None, description="The content of the message.") # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message) # user_id: Optional[str] = Field(None, description="The unique identifier of the user.") # agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") @@ -119,7 +118,7 @@ class Message(BaseMessage): id: str = BaseMessage.generate_id_field() role: MessageRole = Field(..., description="The role of the participant.") - content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.") + content: Optional[List[LettaMessageContentUnion]] = Field(None, description="The content of the message.") organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.") agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") model: Optional[str] = Field(None, description="The model used to make the function call.") @@ -215,7 +214,7 @@ class Message(BaseMessage): assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" - if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text: + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text else: text_content = None @@ -486,7 +485,7 @@ class Message(BaseMessage): """Go from Message class to ChatCompletion message object""" # TODO change to pydantic casting, eg `return SystemMessageModel(self)` - if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text: + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text else: text_content = None @@ -561,7 +560,7 @@ class Message(BaseMessage): Args: inner_thoughts_xml_tag (str): The XML tag to wrap around inner thoughts """ - if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text: + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text else: text_content = None @@ -656,7 +655,7 @@ class Message(BaseMessage): # type Content: https://ai.google.dev/api/rest/v1/Content / https://ai.google.dev/api/rest/v1beta/Content # parts[]: Part # role: str ('user' or 'model') - if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text: + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text else: text_content = None @@ -782,7 +781,7 @@ class Message(BaseMessage): # TODO: update this prompt style once guidance from Cohere on # embedded function calls in multi-turn conversation become more clear - if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text: + if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent): text_content = self.content[0].text else: text_content = None diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 124eecf4..d249a3a2 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -17,6 +17,7 @@ from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaU from letta.log import get_logger from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError from letta.schemas.letta_message import create_letta_message_union_schema +from letta.schemas.letta_message_content import create_letta_message_content_union_schema from letta.server.constants import REST_DEFAULT_PORT # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests @@ -68,6 +69,7 @@ def generate_openapi_schema(app: FastAPI): letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")} letta_docs["info"]["title"] = "Letta API" letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema() + letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema() for name, docs in [ ( "letta", diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 0c0afa51..fc91b6fa 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -530,7 +530,7 @@ def list_messages( ) -@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message") +@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_message") def modify_message( agent_id: str, message_id: str, diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index b349e32c..bbe2df32 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -18,7 +18,7 @@ from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import TextContent +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User diff --git a/letta/server/server.py b/letta/server/server.py index 1eebe729..00a80a47 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -49,10 +49,11 @@ from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.group import Group, ManagerType from letta.schemas.job import Job, JobUpdate from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage +from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary -from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent +from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate from letta.schemas.organization import Organization from letta.schemas.passage import Passage, PassageUpdate from letta.schemas.providers import ( diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index e1d3c91a..a87af65f 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -11,8 +11,9 @@ from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.schemas.agent import AgentState, AgentType from letta.schemas.enums import MessageRole +from letta.schemas.letta_message_content import TextContent from letta.schemas.memory import Memory -from letta.schemas.message import Message, MessageCreate, TextContent +from letta.schemas.message import Message, MessageCreate from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.tool_rule import ToolRule from letta.schemas.user import User diff --git a/letta/supervisor_multi_agent.py b/letta/supervisor_multi_agent.py index 55f481ad..57991700 100644 --- a/letta/supervisor_multi_agent.py +++ b/letta/supervisor_multi_agent.py @@ -6,7 +6,7 @@ from letta.functions.function_sets.multi_agent import send_message_to_all_agents from letta.interface import AgentInterface from letta.orm import User from letta.orm.enums import ToolType -from letta.schemas.letta_message import TextContent +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from letta.schemas.usage import LettaUsageStatistics diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 3c34b0a0..87c63245 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -13,8 +13,9 @@ from letta.errors import ContextWindowExceededError from letta.llm_api.helpers import calculate_summarizer_cutoff from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole +from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message, TextContent +from letta.schemas.message import Message from letta.settings import summarizer_settings from letta.streaming_interface import StreamingRefreshCLIInterface from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH diff --git a/tests/test_static_buffer_summarize.py b/tests/test_static_buffer_summarize.py index 0fa18582..9feb4137 100644 --- a/tests/test_static_buffer_summarize.py +++ b/tests/test_static_buffer_summarize.py @@ -6,7 +6,7 @@ import pytest from letta.agents.base_agent import BaseAgent from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import TextContent +from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message from letta.services.summarizer.enums import SummarizationMode from letta.services.summarizer.summarizer import Summarizer