From d8966d8c7ee157984cc4002c4b6e45cd948675d3 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 13 Mar 2025 18:43:32 -0700 Subject: [PATCH] feat: add content parts to message schema (#1273) Co-authored-by: Matt Zhou --- ...ceb07c2384_add_content_parts_to_message.py | 32 ++++ letta/helpers/converters.py | 64 ++++++++ letta/llm_api/openai.py | 2 +- letta/orm/custom_columns.py | 15 ++ letta/orm/message.py | 8 +- letta/schemas/letta_message.py | 76 ++++++--- letta/schemas/letta_message_content.py | 144 +++++++++++++++++- letta/schemas/message.py | 48 +++--- .../marshmallow_custom_fields.py | 12 ++ .../pydantic_agent_schema.py | 3 +- letta/server/rest_api/app.py | 9 +- letta/server/server.py | 2 +- letta/services/message_manager.py | 14 +- tests/test_agent_serialization.py | 4 +- tests/test_managers.py | 33 ++-- 15 files changed, 385 insertions(+), 81 deletions(-) create mode 100644 alembic/versions/2cceb07c2384_add_content_parts_to_message.py diff --git a/alembic/versions/2cceb07c2384_add_content_parts_to_message.py b/alembic/versions/2cceb07c2384_add_content_parts_to_message.py new file mode 100644 index 00000000..1914418c --- /dev/null +++ b/alembic/versions/2cceb07c2384_add_content_parts_to_message.py @@ -0,0 +1,32 @@ +"""add content parts to message + +Revision ID: 2cceb07c2384 +Revises: 77de976590ae +Create Date: 2025-03-13 14:30:53.177061 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op +from letta.orm.custom_columns import MessageContentColumn + +# revision identifiers, used by Alembic. +revision: str = "2cceb07c2384" +down_revision: Union[str, None] = "77de976590ae" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("messages", sa.Column("content", MessageContentColumn(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("messages", "content") + # ### end Alembic commands ### diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index f35ec69f..73d1196f 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -8,6 +8,16 @@ from sqlalchemy import Dialect from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ToolRuleType +from letta.schemas.letta_message_content import ( + MessageContent, + MessageContentType, + OmittedReasoningContent, + ReasoningContent, + RedactedReasoningContent, + TextContent, + ToolCallContent, + ToolReturnContent, +) from letta.schemas.llm_config import LLMConfig from letta.schemas.message import ToolReturn from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule @@ -166,6 +176,60 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]: return tool_returns +# ---------------------------- +# MessageContent Serialization +# ---------------------------- + + +def serialize_message_content(message_content: Optional[List[Union[MessageContent, dict]]]) -> List[Dict]: + """Convert a list of MessageContent objects into JSON-serializable format.""" + if not message_content: + return [] + + serialized_message_content = [] + for content in message_content: + if isinstance(content, MessageContent): + serialized_message_content.append(content.model_dump()) + elif isinstance(content, dict): + serialized_message_content.append(content) # Already a dictionary, leave it as-is + else: + raise TypeError(f"Unexpected message content type: {type(content)}") + + return serialized_message_content + + +def deserialize_message_content(data: Optional[List[Dict]]) -> 
List[MessageContent]: + """Convert a JSON list back into MessageContent objects.""" + if not data: + return [] + + message_content = [] + for item in data: + if not item: + continue + + content_type = item.get("type") + if content_type == MessageContentType.text: + content = TextContent(**item) + elif content_type == MessageContentType.tool_call: + content = ToolCallContent(**item) + elif content_type == MessageContentType.tool_return: + content = ToolReturnContent(**item) + elif content_type == MessageContentType.reasoning: + content = ReasoningContent(**item) + elif content_type == MessageContentType.redacted_reasoning: + content = RedactedReasoningContent(**item) + elif content_type == MessageContentType.omitted_reasoning: + content = OmittedReasoningContent(**item) + else: + # Skip invalid content + continue + + message_content.append(content) + + return message_content + + # -------------------------- # Vector Serialization # -------------------------- diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index d8ad521b..948730b6 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -221,7 +221,7 @@ def openai_chat_completions_process_stream( # TODO(sarah): add message ID generation function dummy_message = _Message( role=_MessageRole.assistant, - text="", + content=[], agent_id="", model="", name=None, diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index ff8d133c..2f9150d5 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -4,12 +4,14 @@ from sqlalchemy.types import BINARY, TypeDecorator from letta.helpers.converters import ( deserialize_embedding_config, deserialize_llm_config, + deserialize_message_content, deserialize_tool_calls, deserialize_tool_returns, deserialize_tool_rules, deserialize_vector, serialize_embedding_config, serialize_llm_config, + serialize_message_content, serialize_tool_calls, serialize_tool_returns, serialize_tool_rules, @@ -82,6 +84,19 @@ class 
ToolReturnColumn(TypeDecorator): return deserialize_tool_returns(value) +class MessageContentColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing the content parts of a message as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_message_content(value) + + def process_result_value(self, value, dialect): + return deserialize_message_content(value) + + class CommonVector(TypeDecorator): """Custom SQLAlchemy column type for storing vectors in SQLite.""" diff --git a/letta/orm/message.py b/letta/orm/message.py index 64660d41..d8ee5692 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -4,9 +4,10 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe from sqlalchemy import ForeignKey, Index from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.custom_columns import ToolCallColumn, ToolReturnColumn +from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.letta_message_content import MessageContent from letta.schemas.letta_message_content import TextContent as PydanticTextContent from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import ToolReturn @@ -25,6 +26,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier") role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)") text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content") + content: Mapped[List[MessageContent]] = mapped_column(MessageContentColumn, nullable=True, doc="Message content parts") model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used") name: Mapped[Optional[str]] = mapped_column(nullable=True, 
doc="Name for multi-agent scenarios") tool_calls: Mapped[List[OpenAIToolCall]] = mapped_column(ToolCallColumn, doc="Tool call information") @@ -54,8 +56,8 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): return self.job_message.job if self.job_message else None def to_pydantic(self) -> PydanticMessage: - """custom pydantic conversion for message content mapping""" + """Custom pydantic conversion to handle data using legacy text field""" model = self.__pydantic_model__.model_validate(self) - if self.text: + if self.text and not model.content: model.content = [PydanticTextContent(text=self.text)] return model diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index 704e1298..c10c34da 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -4,7 +4,12 @@ from typing import Annotated, List, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator -from letta.schemas.letta_message_content import LettaMessageContentUnion, get_letta_message_content_union_str_json_schema +from letta.schemas.letta_message_content import ( + LettaAssistantMessageContentUnion, + LettaUserMessageContentUnion, + get_letta_assistant_message_content_union_str_json_schema, + get_letta_user_message_content_union_str_json_schema, +) # --------------------------- # Letta API Messaging Schemas @@ -20,11 +25,12 @@ class LettaMessage(BaseModel): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format - + name (Optional[str]): The name of the sender of the message """ id: str date: datetime + name: Optional[str] = None @field_serializer("date") def serialize_datetime(self, dt: datetime, _info): @@ -44,15 +50,12 @@ class SystemMessage(LettaMessage): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format - content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the system (can be a 
string or an array of content parts) + name (Optional[str]): The name of the sender of the message + content (str): The message content sent by the system """ message_type: Literal["system_message"] = "system_message" - content: Union[str, List[LettaMessageContentUnion]] = Field( - ..., - description="The message content sent by the system (can be a string or an array of content parts)", - json_schema_extra=get_letta_message_content_union_str_json_schema(), - ) + content: str = Field(..., description="The message content sent by the system") class UserMessage(LettaMessage): @@ -62,14 +65,15 @@ class UserMessage(LettaMessage): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format - content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts) + name (Optional[str]): The name of the sender of the message + content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts) """ message_type: Literal["user_message"] = "user_message" - content: Union[str, List[LettaMessageContentUnion]] = Field( + content: Union[str, List[LettaUserMessageContentUnion]] = Field( ..., - description="The message content sent by the user (can be a string or an array of content parts)", - json_schema_extra=get_letta_message_content_union_str_json_schema(), + description="The message content sent by the user (can be a string or an array of multi-modal content parts)", + json_schema_extra=get_letta_user_message_content_union_str_json_schema(), ) @@ -80,10 +84,33 @@ class ReasoningMessage(LettaMessage): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + name (Optional[str]): The name of the sender of the message + source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning + content was generated natively by a 
reasoner model or derived via prompting reasoning (str): The internal reasoning of the agent """ message_type: Literal["reasoning_message"] = "reasoning_message" + source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model" + reasoning: str + + +class HiddenReasoningMessage(LettaMessage): + """ + Representation of an agent's internal reasoning where reasoning content + has been hidden from the response. + + Args: + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + name (Optional[str]): The name of the sender of the message + state (Literal["redacted", "omitted"]): Whether the reasoning + content was redacted by the provider or simply omitted by the API + reasoning (str): The internal reasoning of the agent + """ + + message_type: Literal["reasoning_message"] = "reasoning_message" + state: Literal["redacted", "omitted"] reasoning: str @@ -117,6 +144,7 @@ class ToolCallMessage(LettaMessage): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + name (Optional[str]): The name of the sender of the message tool_call (Union[ToolCall, ToolCallDelta]): The tool call """ @@ -164,6 +192,7 @@ class ToolReturnMessage(LettaMessage): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format + name (Optional[str]): The name of the sender of the message tool_return (str): The return value of the tool status (Literal["success", "error"]): The status of the tool call tool_call_id (str): A unique identifier for the tool call that generated this message @@ -186,14 +215,15 @@ class AssistantMessage(LettaMessage): Args: id (str): The ID of the message date (datetime): The date the message was created in ISO format - content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts) + name (Optional[str]): The name of the sender of the message + content 
(Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts) """ message_type: Literal["assistant_message"] = "assistant_message" - content: Union[str, List[LettaMessageContentUnion]] = Field( + content: Union[str, List[LettaAssistantMessageContentUnion]] = Field( ..., description="The message content sent by the agent (can be a string or an array of content parts)", - json_schema_extra=get_letta_message_content_union_str_json_schema(), + json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(), ) @@ -235,19 +265,17 @@ def create_letta_message_union_schema(): class UpdateSystemMessage(BaseModel): message_type: Literal["system_message"] = "system_message" - content: Union[str, List[LettaMessageContentUnion]] = Field( - ..., - description="The message content sent by the system (can be a string or an array of content parts)", - json_schema_extra=get_letta_message_content_union_str_json_schema(), + content: str = Field( + ..., description="The message content sent by the system" ) class UpdateUserMessage(BaseModel): message_type: Literal["user_message"] = "user_message" - content: Union[str, List[LettaMessageContentUnion]] = Field( + content: Union[str, List[LettaUserMessageContentUnion]] = Field( ..., - description="The message content sent by the user (can be a string or an array of content parts)", - json_schema_extra=get_letta_message_content_union_str_json_schema(), + description="The message content sent by the user (can be a string or an array of multi-modal content parts)", + json_schema_extra=get_letta_user_message_content_union_str_json_schema(), ) @@ -258,10 +286,10 @@ class UpdateReasoningMessage(BaseModel): class UpdateAssistantMessage(BaseModel): message_type: Literal["assistant_message"] = "assistant_message" - content: Union[str, List[LettaMessageContentUnion]] = Field( + content: Union[str, 
List[LettaAssistantMessageContentUnion]] = Field( ..., description="The message content sent by the assistant (can be a string or an array of content parts)", - json_schema_extra=get_letta_message_content_union_str_json_schema(), + json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(), ) diff --git a/letta/schemas/letta_message_content.py b/letta/schemas/letta_message_content.py index 2c27f492..00ebfe78 100644 --- a/letta/schemas/letta_message_content.py +++ b/letta/schemas/letta_message_content.py @@ -1,11 +1,16 @@ from enum import Enum -from typing import Annotated, Literal, Union +from typing import Annotated, Literal, Optional, Union from pydantic import BaseModel, Field class MessageContentType(str, Enum): text = "text" + tool_call = "tool_call" + tool_return = "tool_return" + reasoning = "reasoning" + redacted_reasoning = "redacted_reasoning" + omitted_reasoning = "omitted_reasoning" class MessageContent(BaseModel): @@ -13,7 +18,7 @@ class MessageContent(BaseModel): # ------------------------------- -# Multi-Modal Content Types +# User Content Types # ------------------------------- @@ -22,13 +27,13 @@ class TextContent(MessageContent): text: str = Field(..., description="The text content of the message.") -LettaMessageContentUnion = Annotated[ +LettaUserMessageContentUnion = Annotated[ Union[TextContent], Field(discriminator="type"), ] -def create_letta_message_content_union_schema(): +def create_letta_user_message_content_union_schema(): return { "oneOf": [ {"$ref": "#/components/schemas/TextContent"}, @@ -42,6 +47,137 @@ def create_letta_message_content_union_schema(): } +def get_letta_user_message_content_union_str_json_schema(): + return { + "anyOf": [ + { + "type": "array", + "items": { + "$ref": "#/components/schemas/LettaUserMessageContentUnion", + }, + }, + {"type": "string"}, + ], + } + + +# ------------------------------- +# Assistant Content Types +# ------------------------------- + + +LettaAssistantMessageContentUnion = 
Annotated[ + Union[TextContent], + Field(discriminator="type"), +] + + +def create_letta_assistant_message_content_union_schema(): + return { + "oneOf": [ + {"$ref": "#/components/schemas/TextContent"}, + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/TextContent", + }, + }, + } + + +def get_letta_assistant_message_content_union_str_json_schema(): + return { + "anyOf": [ + { + "type": "array", + "items": { + "$ref": "#/components/schemas/LettaAssistantMessageContentUnion", + }, + }, + {"type": "string"}, + ], + } + + +# ------------------------------- +# Intermediate Step Content Types +# ------------------------------- + + +class ToolCallContent(MessageContent): + type: Literal[MessageContentType.tool_call] = Field( + MessageContentType.tool_call, description="Indicates this content represents a tool call event." + ) + id: str = Field(..., description="A unique identifier for this specific tool call instance.") + name: str = Field(..., description="The name of the tool being called.") + input: dict = Field( + ..., description="The parameters being passed to the tool, structured as a dictionary of parameter names to values." + ) + + +class ToolReturnContent(MessageContent): + type: Literal[MessageContentType.tool_return] = Field( + MessageContentType.tool_return, description="Indicates this content represents a tool return event." + ) + tool_call_id: str = Field(..., description="References the ID of the ToolCallContent that initiated this tool call.") + content: str = Field(..., description="The content returned by the tool execution.") + is_error: bool = Field(..., description="Indicates whether the tool execution resulted in an error.") + + +class ReasoningContent(MessageContent): + type: Literal[MessageContentType.reasoning] = Field( + MessageContentType.reasoning, description="Indicates this is a reasoning/intermediate step." 
+ ) + is_native: bool = Field(..., description="Whether the reasoning content was generated by a reasoner model that processed this step.") + reasoning: str = Field(..., description="The intermediate reasoning or thought process content.") + signature: Optional[str] = Field(None, description="A unique identifier for this reasoning step.") + + +class RedactedReasoningContent(MessageContent): + type: Literal[MessageContentType.redacted_reasoning] = Field( + MessageContentType.redacted_reasoning, description="Indicates this is a redacted thinking step." + ) + data: str = Field(..., description="The redacted or filtered intermediate reasoning content.") + + +class OmittedReasoningContent(MessageContent): + type: Literal[MessageContentType.omitted_reasoning] = Field( + MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step." + ) + tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.") + + +LettaMessageContentUnion = Annotated[ + Union[TextContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent], + Field(discriminator="type"), +] + + +def create_letta_message_content_union_schema(): + return { + "oneOf": [ + {"$ref": "#/components/schemas/TextContent"}, + {"$ref": "#/components/schemas/ToolCallContent"}, + {"$ref": "#/components/schemas/ToolReturnContent"}, + {"$ref": "#/components/schemas/ReasoningContent"}, + {"$ref": "#/components/schemas/RedactedReasoningContent"}, + {"$ref": "#/components/schemas/OmittedReasoningContent"}, + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/TextContent", + "tool_call": "#/components/schemas/ToolCallContent", + "tool_return": "#/components/schemas/ToolReturnContent", + "reasoning": "#/components/schemas/ReasoningContent", + "redacted_reasoning": "#/components/schemas/RedactedReasoningContent", + "omitted_reasoning": 
"#/components/schemas/OmittedReasoningContent", + }, + }, + } + + def get_letta_message_content_union_str_json_schema(): return { "anyOf": [ diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 8190f60d..3a34b8e8 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime @@ -27,7 +27,7 @@ from letta.schemas.letta_message import ( ToolReturnMessage, UserMessage, ) -from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent +from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent, get_letta_message_content_union_str_json_schema from letta.system import unpack_message @@ -65,15 +65,30 @@ class MessageCreate(BaseModel): MessageRole.user, MessageRole.system, ] = Field(..., description="The role of the participant.") - content: Union[str, List[LettaMessageContentUnion]] = Field(..., description="The content of the message.") + content: Union[str, List[LettaMessageContentUnion]] = Field( + ..., + description="The content of the message.", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) name: Optional[str] = Field(None, description="The name of the participant.") + def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: + data = super().model_dump(**kwargs) + if to_orm and "content" in data: + if isinstance(data["content"], str): + data["content"] = 
[TextContent(text=data["content"])] + return data + class MessageUpdate(BaseModel): """Request to update a message""" role: Optional[MessageRole] = Field(None, description="The role of the participant.") - content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(None, description="The content of the message.") + content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field( + None, + description="The content of the message.", + json_schema_extra=get_letta_message_content_union_str_json_schema(), + ) # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message) # user_id: Optional[str] = Field(None, description="The unique identifier of the user.") # agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") @@ -89,12 +104,7 @@ class MessageUpdate(BaseModel): data = super().model_dump(**kwargs) if to_orm and "content" in data: if isinstance(data["content"], str): - data["text"] = data["content"] - else: - for content in data["content"]: - if content["type"] == "text": - data["text"] = content["text"] - del data["content"] + data["content"] = [TextContent(text=data["content"])] return data @@ -140,24 +150,6 @@ class Message(BaseMessage): assert v in roles, f"Role must be one of {roles}" return v - @model_validator(mode="before") - @classmethod - def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]: - if isinstance(data, dict): - if "text" in data and "content" not in data: - data["content"] = [TextContent(text=data["text"])] - del data["text"] - return data - - def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: - data = super().model_dump(**kwargs) - if to_orm: - for content in data["content"]: - if content["type"] == "text": - data["text"] = content["text"] - del data["content"] - return data - def to_json(self): json_message = vars(self) if json_message["tool_calls"] is not None: diff --git 
a/letta/serialize_schemas/marshmallow_custom_fields.py b/letta/serialize_schemas/marshmallow_custom_fields.py index 4478659e..ebc7166d 100644 --- a/letta/serialize_schemas/marshmallow_custom_fields.py +++ b/letta/serialize_schemas/marshmallow_custom_fields.py @@ -3,10 +3,12 @@ from marshmallow import fields from letta.helpers.converters import ( deserialize_embedding_config, deserialize_llm_config, + deserialize_message_content, deserialize_tool_calls, deserialize_tool_rules, serialize_embedding_config, serialize_llm_config, + serialize_message_content, serialize_tool_calls, serialize_tool_rules, ) @@ -67,3 +69,13 @@ class ToolCallField(fields.Field): def _deserialize(self, value, attr, data, **kwargs): return deserialize_tool_calls(value) + + +class MessageContentField(fields.Field): + """Marshmallow field for handling a list of Message Content Part objects.""" + + def _serialize(self, value, attr, obj, **kwargs): + return serialize_message_content(value) + + def _deserialize(self, value, attr, data, **kwargs): + return deserialize_message_content(value) diff --git a/letta/serialize_schemas/pydantic_agent_schema.py b/letta/serialize_schemas/pydantic_agent_schema.py index 51087125..ce1d65ce 100644 --- a/letta/serialize_schemas/pydantic_agent_schema.py +++ b/letta/serialize_schemas/pydantic_agent_schema.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig @@ -27,7 +28,7 @@ class MessageSchema(BaseModel): model: Optional[str] name: Optional[str] role: str - text: str + content: List[TextContent] # TODO: Expand to more in the future tool_call_id: Optional[str] tool_calls: List[Any] tool_returns: List[Any] diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index a4013269..d6aa2ef8 100644 --- a/letta/server/rest_api/app.py +++ 
b/letta/server/rest_api/app.py @@ -17,7 +17,11 @@ from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaU from letta.log import get_logger from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError from letta.schemas.letta_message import create_letta_message_union_schema -from letta.schemas.letta_message_content import create_letta_message_content_union_schema +from letta.schemas.letta_message_content import ( + create_letta_assistant_message_content_union_schema, + create_letta_message_content_union_schema, + create_letta_user_message_content_union_schema, +) from letta.server.constants import REST_DEFAULT_PORT # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests @@ -70,6 +74,9 @@ def generate_openapi_schema(app: FastAPI): letta_docs["info"]["title"] = "Letta API" letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema() letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema() + letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema() + letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema() + for name, docs in [ ( "letta", diff --git a/letta/server/server.py b/letta/server/server.py index 336ba086..0a7de95f 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -740,7 +740,7 @@ class SyncServer(Server): Message( agent_id=agent_id, role=message.role, - content=[TextContent(text=message.content)], + content=[TextContent(text=message.content)] if message.content else [], name=message.name, # assigned later? 
model=None, diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index b07cb5d8..387168dc 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,7 +1,7 @@ import json from typing import List, Optional -from sqlalchemy import and_, or_ +from sqlalchemy import and_, exists, func, or_, select, text from letta.log import get_logger from letta.orm.agent import Agent as AgentModel @@ -233,9 +233,17 @@ class MessageManager: # Build a query that directly filters the Message table by agent_id. query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id) - # If query_text is provided, filter messages by partial match on text. + # If query_text is provided, filter messages using subquery. if query_text: - query = query.filter(MessageModel.text.ilike(f"%{query_text}%")) + content_element = func.json_array_elements(MessageModel.content).alias("content_element") + query = query.filter( + exists( + select(1) + .select_from(content_element) + .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text")) + .params(query_text=f"%{query_text}%") + ) + ) # If role is provided, filter messages by role. 
if role: diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 0fe87329..b2607291 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -482,7 +482,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent agent_id = serialize_test_agent.id # Step 1: Download the serialized agent - response = fastapi_client.get(f"/v1/agents/{agent_id}/download", headers={"user_id": default_user.id}) + response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id}) assert response.status_code == 200, f"Download failed: {response.text}" # Ensure response matches expected schema @@ -493,7 +493,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8")) files = {"file": ("agent.json", agent_bytes, "application/json")} upload_response = fastapi_client.post( - "/v1/agents/upload", + "/v1/agents/import", headers={"user_id": other_user.id}, params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id}, files=files, diff --git a/tests/test_managers.py b/tests/test_managers.py index 48416d16..1d1dc240 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -25,6 +25,7 @@ from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPro from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage +from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -272,7 +273,7 @@ def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent): 
organization_id=default_user.organization_id, agent_id=sarah_agent.id, role="user", - text="Hello, world!", + content=[TextContent(text="Hello, world!")], ) msg = server.message_manager.create_message(message, actor=default_user) @@ -1196,7 +1197,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent, agent_id=sarah_agent.id, organization_id=default_user.organization_id, role="user", - text="Hello, Sarah!", + content=[TextContent(text="Hello, Sarah!")], ), actor=default_user, ) @@ -1205,7 +1206,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent, agent_id=sarah_agent.id, organization_id=default_user.organization_id, role="assistant", - text="Hello, user!", + content=[TextContent(text="Hello, user!")], ), actor=default_user, ) @@ -1236,7 +1237,7 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use agent_id=sarah_agent.id, organization_id=default_user.organization_id, role="user", - text="Hello, Sarah!", + content=[TextContent(text="Hello, Sarah!")], ), actor=default_user, ) @@ -2062,7 +2063,10 @@ def test_message_size(server: SyncServer, hello_world_message_fixture, default_u # Create additional test messages messages = [ PydanticMessage( - organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}" + organization_id=default_user.organization_id, + agent_id=base_message.agent_id, + role=base_message.role, + content=[TextContent(text=f"Test message {i}")], ) for i in range(4) ] @@ -2090,7 +2094,10 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa """Helper function to create test messages for all tests""" messages = [ PydanticMessage( - organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}" + organization_id=default_user.organization_id, + agent_id=base_message.agent_id, + role=base_message.role, + 
content=[TextContent(text=f"Test message {i}")], ) for i in range(4) ] @@ -3270,7 +3277,7 @@ def test_job_messages_pagination(server: SyncServer, default_run, default_user, organization_id=default_user.organization_id, agent_id=sarah_agent.id, role=MessageRole.user, - text=f"Test message {i}", + content=[TextContent(text=f"Test message {i}")], ) msg = server.message_manager.create_message(message, actor=default_user) message_ids.append(msg.id) @@ -3383,7 +3390,7 @@ def test_job_messages_ordering(server: SyncServer, default_run, default_user, sa for i, created_at in enumerate(message_times): message = PydanticMessage( role=MessageRole.user, - text="Test message", + content=[TextContent(text="Test message")], organization_id=default_user.organization_id, agent_id=sarah_agent.id, created_at=created_at, @@ -3452,19 +3459,19 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara messages = [ PydanticMessage( role=MessageRole.user, - text="Hello", + content=[TextContent(text="Hello")], organization_id=default_user.organization_id, agent_id=sarah_agent.id, ), PydanticMessage( role=MessageRole.assistant, - text="Hi there!", + content=[TextContent(text="Hi there!")], organization_id=default_user.organization_id, agent_id=sarah_agent.id, ), PydanticMessage( role=MessageRole.assistant, - text="Let me help you with that", + content=[TextContent(text="Let me help you with that")], organization_id=default_user.organization_id, agent_id=sarah_agent.id, tool_calls=[ @@ -3519,7 +3526,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ organization_id=default_user.organization_id, agent_id=sarah_agent.id, role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, - text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}', + content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], tool_calls=( [{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", 
"arguments": '{"custom_arg": "test"}'}}] if i % 2 == 1 @@ -3570,7 +3577,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ organization_id=default_user.organization_id, agent_id=sarah_agent.id, role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, - text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}', + content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], tool_calls=( [{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] if i % 2 == 1