feat: add content parts to message schema (#1273)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
"""add content parts to message
|
||||
|
||||
Revision ID: 2cceb07c2384
|
||||
Revises: 77de976590ae
|
||||
Create Date: 2025-03-13 14:30:53.177061
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.orm.custom_columns import MessageContentColumn
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2cceb07c2384"
|
||||
down_revision: Union[str, None] = "77de976590ae"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("messages", sa.Column("content", MessageContentColumn(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("messages", "content")
|
||||
# ### end Alembic commands ###
|
||||
@@ -8,6 +8,16 @@ from sqlalchemy import Dialect
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.letta_message_content import (
|
||||
MessageContent,
|
||||
MessageContentType,
|
||||
OmittedReasoningContent,
|
||||
ReasoningContent,
|
||||
RedactedReasoningContent,
|
||||
TextContent,
|
||||
ToolCallContent,
|
||||
ToolReturnContent,
|
||||
)
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule
|
||||
@@ -166,6 +176,60 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
|
||||
return tool_returns
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# MessageContent Serialization
|
||||
# ----------------------------
|
||||
|
||||
|
||||
def serialize_message_content(message_content: Optional[List[Union[MessageContent, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of MessageContent objects into JSON-serializable format."""
|
||||
if not message_content:
|
||||
return []
|
||||
|
||||
serialized_message_content = []
|
||||
for content in message_content:
|
||||
if isinstance(content, MessageContent):
|
||||
serialized_message_content.append(content.model_dump())
|
||||
elif isinstance(content, dict):
|
||||
serialized_message_content.append(content) # Already a dictionary, leave it as-is
|
||||
else:
|
||||
raise TypeError(f"Unexpected message content type: {type(content)}")
|
||||
|
||||
return serialized_message_content
|
||||
|
||||
|
||||
def deserialize_message_content(data: Optional[List[Dict]]) -> List[MessageContent]:
|
||||
"""Convert a JSON list back into MessageContent objects."""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
message_content = []
|
||||
for item in data:
|
||||
if not item:
|
||||
continue
|
||||
|
||||
content_type = item.get("type")
|
||||
if content_type == MessageContentType.text:
|
||||
content = TextContent(**item)
|
||||
elif content_type == MessageContentType.tool_call:
|
||||
content = ToolCallContent(**item)
|
||||
elif content_type == MessageContentType.tool_return:
|
||||
content = ToolReturnContent(**item)
|
||||
elif content_type == MessageContentType.reasoning:
|
||||
content = ReasoningContent(**item)
|
||||
elif content_type == MessageContentType.redacted_reasoning:
|
||||
content = RedactedReasoningContent(**item)
|
||||
elif content_type == MessageContentType.omitted_reasoning:
|
||||
content = OmittedReasoningContent(**item)
|
||||
else:
|
||||
# Skip invalid content
|
||||
continue
|
||||
|
||||
message_content.append(content)
|
||||
|
||||
return message_content
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Vector Serialization
|
||||
# --------------------------
|
||||
|
||||
@@ -221,7 +221,7 @@ def openai_chat_completions_process_stream(
|
||||
# TODO(sarah): add message ID generation function
|
||||
dummy_message = _Message(
|
||||
role=_MessageRole.assistant,
|
||||
text="",
|
||||
content=[],
|
||||
agent_id="",
|
||||
model="",
|
||||
name=None,
|
||||
|
||||
@@ -4,12 +4,14 @@ from sqlalchemy.types import BINARY, TypeDecorator
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_message_content,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_returns,
|
||||
deserialize_tool_rules,
|
||||
deserialize_vector,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_message_content,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_returns,
|
||||
serialize_tool_rules,
|
||||
@@ -82,6 +84,19 @@ class ToolReturnColumn(TypeDecorator):
|
||||
return deserialize_tool_returns(value)
|
||||
|
||||
|
||||
class MessageContentColumn(TypeDecorator):
|
||||
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return serialize_message_content(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_message_content(value)
|
||||
|
||||
|
||||
class CommonVector(TypeDecorator):
|
||||
"""Custom SQLAlchemy column type for storing vectors in SQLite."""
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
|
||||
from sqlalchemy import ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.letta_message_content import MessageContent
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import ToolReturn
|
||||
@@ -25,6 +26,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
|
||||
role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)")
|
||||
text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content")
|
||||
content: Mapped[List[MessageContent]] = mapped_column(MessageContentColumn, nullable=True, doc="Message content parts")
|
||||
model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used")
|
||||
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
||||
tool_calls: Mapped[List[OpenAIToolCall]] = mapped_column(ToolCallColumn, doc="Tool call information")
|
||||
@@ -54,8 +56,8 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
return self.job_message.job if self.job_message else None
|
||||
|
||||
def to_pydantic(self) -> PydanticMessage:
|
||||
"""custom pydantic conversion for message content mapping"""
|
||||
"""Custom pydantic conversion to handle data using legacy text field"""
|
||||
model = self.__pydantic_model__.model_validate(self)
|
||||
if self.text:
|
||||
if self.text and not model.content:
|
||||
model.content = [PydanticTextContent(text=self.text)]
|
||||
return model
|
||||
|
||||
@@ -4,7 +4,12 @@ from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
from letta.schemas.letta_message_content import LettaMessageContentUnion, get_letta_message_content_union_str_json_schema
|
||||
from letta.schemas.letta_message_content import (
|
||||
LettaAssistantMessageContentUnion,
|
||||
LettaUserMessageContentUnion,
|
||||
get_letta_assistant_message_content_union_str_json_schema,
|
||||
get_letta_user_message_content_union_str_json_schema,
|
||||
)
|
||||
|
||||
# ---------------------------
|
||||
# Letta API Messaging Schemas
|
||||
@@ -20,11 +25,12 @@ class LettaMessage(BaseModel):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
"""
|
||||
|
||||
id: str
|
||||
date: datetime
|
||||
name: Optional[str] = None
|
||||
|
||||
@field_serializer("date")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
@@ -44,15 +50,12 @@ class SystemMessage(LettaMessage):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the system (can be a string or an array of content parts)
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
content (str): The message content sent by the system
|
||||
"""
|
||||
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the system (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
content: str = Field(..., description="The message content sent by the system")
|
||||
|
||||
|
||||
class UserMessage(LettaMessage):
|
||||
@@ -62,14 +65,15 @@ class UserMessage(LettaMessage):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the user (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
|
||||
json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
@@ -80,10 +84,33 @@ class ReasoningMessage(LettaMessage):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning
|
||||
content was generated natively by a reasoner model or derived via prompting
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
"""
|
||||
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model"
|
||||
reasoning: str
|
||||
|
||||
|
||||
class HiddenReasoningMessage(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal reasoning where reasoning content
|
||||
has been hidden from the response.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
state (Literal["redacted", "omitted"]): Whether the reasoning
|
||||
content was redacted by the provider or simply omitted by the API
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
"""
|
||||
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
state: Literal["redacted", "omitted"]
|
||||
reasoning: str
|
||||
|
||||
|
||||
@@ -117,6 +144,7 @@ class ToolCallMessage(LettaMessage):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
|
||||
"""
|
||||
|
||||
@@ -164,6 +192,7 @@ class ToolReturnMessage(LettaMessage):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
tool_return (str): The return value of the tool
|
||||
status (Literal["success", "error"]): The status of the tool call
|
||||
tool_call_id (str): A unique identifier for the tool call that generated this message
|
||||
@@ -186,14 +215,15 @@ class AssistantMessage(LettaMessage):
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the agent (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
@@ -235,19 +265,17 @@ def create_letta_message_union_schema():
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the system (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
content: str = Field(
|
||||
..., description="The message content sent by the system (can be a string or an array of multi-modal content parts)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the user (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
|
||||
json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
@@ -258,10 +286,10 @@ class UpdateReasoningMessage(BaseModel):
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the assistant (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageContentType(str, Enum):
|
||||
text = "text"
|
||||
tool_call = "tool_call"
|
||||
tool_return = "tool_return"
|
||||
reasoning = "reasoning"
|
||||
redacted_reasoning = "redacted_reasoning"
|
||||
omitted_reasoning = "omitted_reasoning"
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
@@ -13,7 +18,7 @@ class MessageContent(BaseModel):
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Multi-Modal Content Types
|
||||
# User Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
@@ -22,13 +27,13 @@ class TextContent(MessageContent):
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
LettaMessageContentUnion = Annotated[
|
||||
LettaUserMessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_content_union_schema():
|
||||
def create_letta_user_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
@@ -42,6 +47,137 @@ def create_letta_message_content_union_schema():
|
||||
}
|
||||
|
||||
|
||||
def get_letta_user_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaUserMessageContentUnion",
|
||||
},
|
||||
},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Assistant Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
LettaAssistantMessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_assistant_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/TextContent",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_letta_assistant_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaAssistantMessageContentUnion",
|
||||
},
|
||||
},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Intermediate Step Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
class ToolCallContent(MessageContent):
|
||||
type: Literal[MessageContentType.tool_call] = Field(
|
||||
MessageContentType.tool_call, description="Indicates this content represents a tool call event."
|
||||
)
|
||||
id: str = Field(..., description="A unique identifier for this specific tool call instance.")
|
||||
name: str = Field(..., description="The name of the tool being called.")
|
||||
input: dict = Field(
|
||||
..., description="The parameters being passed to the tool, structured as a dictionary of parameter names to values."
|
||||
)
|
||||
|
||||
|
||||
class ToolReturnContent(MessageContent):
|
||||
type: Literal[MessageContentType.tool_return] = Field(
|
||||
MessageContentType.tool_return, description="Indicates this content represents a tool return event."
|
||||
)
|
||||
tool_call_id: str = Field(..., description="References the ID of the ToolCallContent that initiated this tool call.")
|
||||
content: str = Field(..., description="The content returned by the tool execution.")
|
||||
is_error: bool = Field(..., description="Indicates whether the tool execution resulted in an error.")
|
||||
|
||||
|
||||
class ReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.reasoning] = Field(
|
||||
MessageContentType.reasoning, description="Indicates this is a reasoning/intermediate step."
|
||||
)
|
||||
is_native: bool = Field(..., description="Whether the reasoning content was generated by a reasoner model that processed this step.")
|
||||
reasoning: str = Field(..., description="The intermediate reasoning or thought process content.")
|
||||
signature: Optional[str] = Field(None, description="A unique identifier for this reasoning step.")
|
||||
|
||||
|
||||
class RedactedReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.redacted_reasoning] = Field(
|
||||
MessageContentType.redacted_reasoning, description="Indicates this is a redacted thinking step."
|
||||
)
|
||||
data: str = Field(..., description="The redacted or filtered intermediate reasoning content.")
|
||||
|
||||
|
||||
class OmittedReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.omitted_reasoning] = Field(
|
||||
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
|
||||
)
|
||||
tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
|
||||
|
||||
|
||||
LettaMessageContentUnion = Annotated[
|
||||
Union[TextContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
{"$ref": "#/components/schemas/ToolCallContent"},
|
||||
{"$ref": "#/components/schemas/ToolReturnContent"},
|
||||
{"$ref": "#/components/schemas/ReasoningContent"},
|
||||
{"$ref": "#/components/schemas/RedactedReasoningContent"},
|
||||
{"$ref": "#/components/schemas/OmittedReasoningContent"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/TextContent",
|
||||
"tool_call": "#/components/schemas/ToolCallContent",
|
||||
"tool_return": "#/components/schemas/ToolCallContent",
|
||||
"reasoning": "#/components/schemas/ReasoningContent",
|
||||
"redacted_reasoning": "#/components/schemas/RedactedReasoningContent",
|
||||
"omitted_reasoning": "#/components/schemas/OmittedReasoningContent",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_letta_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
||||
@@ -27,7 +27,7 @@ from letta.schemas.letta_message import (
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent
|
||||
from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent, get_letta_message_content_union_str_json_schema
|
||||
from letta.system import unpack_message
|
||||
|
||||
|
||||
@@ -65,15 +65,30 @@ class MessageCreate(BaseModel):
|
||||
MessageRole.user,
|
||||
MessageRole.system,
|
||||
] = Field(..., description="The role of the participant.")
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(..., description="The content of the message.")
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The content of the message.",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm and "content" in data:
|
||||
if isinstance(data["content"], str):
|
||||
data["content"] = [TextContent(text=data["content"])]
|
||||
return data
|
||||
|
||||
|
||||
class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(None, description="The content of the message.")
|
||||
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
|
||||
None,
|
||||
description="The content of the message.",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
@@ -89,12 +104,7 @@ class MessageUpdate(BaseModel):
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm and "content" in data:
|
||||
if isinstance(data["content"], str):
|
||||
data["text"] = data["content"]
|
||||
else:
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
data["text"] = content["text"]
|
||||
del data["content"]
|
||||
data["content"] = [TextContent(text=data["content"])]
|
||||
return data
|
||||
|
||||
|
||||
@@ -140,24 +150,6 @@ class Message(BaseMessage):
|
||||
assert v in roles, f"Role must be one of {roles}"
|
||||
return v
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if isinstance(data, dict):
|
||||
if "text" in data and "content" not in data:
|
||||
data["content"] = [TextContent(text=data["text"])]
|
||||
del data["text"]
|
||||
return data
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm:
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
data["text"] = content["text"]
|
||||
del data["content"]
|
||||
return data
|
||||
|
||||
def to_json(self):
|
||||
json_message = vars(self)
|
||||
if json_message["tool_calls"] is not None:
|
||||
|
||||
@@ -3,10 +3,12 @@ from marshmallow import fields
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_message_content,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_rules,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_message_content,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_rules,
|
||||
)
|
||||
@@ -67,3 +69,13 @@ class ToolCallField(fields.Field):
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_tool_calls(value)
|
||||
|
||||
|
||||
class MessageContentField(fields.Field):
|
||||
"""Marshmallow field for handling a list of Message Content Part objects."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_message_content(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_message_content(value)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
@@ -27,7 +28,7 @@ class MessageSchema(BaseModel):
|
||||
model: Optional[str]
|
||||
name: Optional[str]
|
||||
role: str
|
||||
text: str
|
||||
content: List[TextContent] # TODO: Expand to more in the future
|
||||
tool_call_id: Optional[str]
|
||||
tool_calls: List[Any]
|
||||
tool_returns: List[Any]
|
||||
|
||||
@@ -17,7 +17,11 @@ from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaU
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import create_letta_message_union_schema
|
||||
from letta.schemas.letta_message_content import create_letta_message_content_union_schema
|
||||
from letta.schemas.letta_message_content import (
|
||||
create_letta_assistant_message_content_union_schema,
|
||||
create_letta_message_content_union_schema,
|
||||
create_letta_user_message_content_union_schema,
|
||||
)
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
@@ -70,6 +74,9 @@ def generate_openapi_schema(app: FastAPI):
|
||||
letta_docs["info"]["title"] = "Letta API"
|
||||
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()
|
||||
|
||||
for name, docs in [
|
||||
(
|
||||
"letta",
|
||||
|
||||
@@ -740,7 +740,7 @@ class SyncServer(Server):
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)],
|
||||
content=[TextContent(text=message.content)] if message.content else [],
|
||||
name=message.name,
|
||||
# assigned later?
|
||||
model=None,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy import and_, exists, func, or_, select, text
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -233,9 +233,17 @@ class MessageManager:
|
||||
# Build a query that directly filters the Message table by agent_id.
|
||||
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
|
||||
|
||||
# If query_text is provided, filter messages by partial match on text.
|
||||
# If query_text is provided, filter messages using subquery.
|
||||
if query_text:
|
||||
query = query.filter(MessageModel.text.ilike(f"%{query_text}%"))
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.filter(
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(content_element)
|
||||
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
|
||||
.params(query_text=f"%{query_text}%")
|
||||
)
|
||||
)
|
||||
|
||||
# If role is provided, filter messages by role.
|
||||
if role:
|
||||
|
||||
@@ -482,7 +482,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
agent_id = serialize_test_agent.id
|
||||
|
||||
# Step 1: Download the serialized agent
|
||||
response = fastapi_client.get(f"/v1/agents/{agent_id}/download", headers={"user_id": default_user.id})
|
||||
response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id})
|
||||
assert response.status_code == 200, f"Download failed: {response.text}"
|
||||
|
||||
# Ensure response matches expected schema
|
||||
@@ -493,7 +493,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
|
||||
files = {"file": ("agent.json", agent_bytes, "application/json")}
|
||||
upload_response = fastapi_client.post(
|
||||
"/v1/agents/upload",
|
||||
"/v1/agents/import",
|
||||
headers={"user_id": other_user.id},
|
||||
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
|
||||
files=files,
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPro
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
@@ -272,7 +273,7 @@ def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
text="Hello, world!",
|
||||
content=[TextContent(text="Hello, world!")],
|
||||
)
|
||||
|
||||
msg = server.message_manager.create_message(message, actor=default_user)
|
||||
@@ -1196,7 +1197,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent,
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="user",
|
||||
text="Hello, Sarah!",
|
||||
content=[TextContent(text="Hello, Sarah!")],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1205,7 +1206,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent,
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="assistant",
|
||||
text="Hello, user!",
|
||||
content=[TextContent(text="Hello, user!")],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1236,7 +1237,7 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="user",
|
||||
text="Hello, Sarah!",
|
||||
content=[TextContent(text="Hello, Sarah!")],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -2062,7 +2063,10 @@ def test_message_size(server: SyncServer, hello_world_message_fixture, default_u
|
||||
# Create additional test messages
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=base_message.agent_id,
|
||||
role=base_message.role,
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
@@ -2090,7 +2094,10 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa
|
||||
"""Helper function to create test messages for all tests"""
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=base_message.agent_id,
|
||||
role=base_message.role,
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
@@ -3270,7 +3277,7 @@ def test_job_messages_pagination(server: SyncServer, default_run, default_user,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
text=f"Test message {i}",
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
)
|
||||
msg = server.message_manager.create_message(message, actor=default_user)
|
||||
message_ids.append(msg.id)
|
||||
@@ -3383,7 +3390,7 @@ def test_job_messages_ordering(server: SyncServer, default_run, default_user, sa
|
||||
for i, created_at in enumerate(message_times):
|
||||
message = PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
text="Test message",
|
||||
content=[TextContent(text="Test message")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
created_at=created_at,
|
||||
@@ -3452,19 +3459,19 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
text="Hello",
|
||||
content=[TextContent(text="Hello")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
),
|
||||
PydanticMessage(
|
||||
role=MessageRole.assistant,
|
||||
text="Hi there!",
|
||||
content=[TextContent(text="Hi there!")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
),
|
||||
PydanticMessage(
|
||||
role=MessageRole.assistant,
|
||||
text="Let me help you with that",
|
||||
content=[TextContent(text="Let me help you with that")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
tool_calls=[
|
||||
@@ -3519,7 +3526,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
|
||||
content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],
|
||||
tool_calls=(
|
||||
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
if i % 2 == 1
|
||||
@@ -3570,7 +3577,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
|
||||
content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],
|
||||
tool_calls=(
|
||||
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
if i % 2 == 1
|
||||
|
||||
Reference in New Issue
Block a user