feat: add content parts to message schema (#1273)

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
cthomas
2025-03-13 18:43:32 -07:00
committed by GitHub
parent b7eb47a4a0
commit aa2f4258c4
15 changed files with 385 additions and 81 deletions

View File

@@ -0,0 +1,32 @@
"""add content parts to message
Revision ID: 2cceb07c2384
Revises: 77de976590ae
Create Date: 2025-03-13 14:30:53.177061
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from letta.orm.custom_columns import MessageContentColumn
# revision identifiers, used by Alembic.
revision: str = "2cceb07c2384"
down_revision: Union[str, None] = "77de976590ae"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("messages", sa.Column("content", MessageContentColumn(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("messages", "content")
# ### end Alembic commands ###

View File

@@ -8,6 +8,16 @@ from sqlalchemy import Dialect
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.letta_message_content import (
MessageContent,
MessageContentType,
OmittedReasoningContent,
ReasoningContent,
RedactedReasoningContent,
TextContent,
ToolCallContent,
ToolReturnContent,
)
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import ToolReturn
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule
@@ -166,6 +176,60 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
return tool_returns
# ----------------------------
# MessageContent Serialization
# ----------------------------
def serialize_message_content(message_content: Optional[List[Union[MessageContent, dict]]]) -> List[Dict]:
"""Convert a list of MessageContent objects into JSON-serializable format."""
if not message_content:
return []
serialized_message_content = []
for content in message_content:
if isinstance(content, MessageContent):
serialized_message_content.append(content.model_dump())
elif isinstance(content, dict):
serialized_message_content.append(content) # Already a dictionary, leave it as-is
else:
raise TypeError(f"Unexpected message content type: {type(content)}")
return serialized_message_content
def deserialize_message_content(data: Optional[List[Dict]]) -> List[MessageContent]:
"""Convert a JSON list back into MessageContent objects."""
if not data:
return []
message_content = []
for item in data:
if not item:
continue
content_type = item.get("type")
if content_type == MessageContentType.text:
content = TextContent(**item)
elif content_type == MessageContentType.tool_call:
content = ToolCallContent(**item)
elif content_type == MessageContentType.tool_return:
content = ToolReturnContent(**item)
elif content_type == MessageContentType.reasoning:
content = ReasoningContent(**item)
elif content_type == MessageContentType.redacted_reasoning:
content = RedactedReasoningContent(**item)
elif content_type == MessageContentType.omitted_reasoning:
content = OmittedReasoningContent(**item)
else:
# Skip invalid content
continue
message_content.append(content)
return message_content
# --------------------------
# Vector Serialization
# --------------------------

View File

@@ -221,7 +221,7 @@ def openai_chat_completions_process_stream(
# TODO(sarah): add message ID generation function
dummy_message = _Message(
role=_MessageRole.assistant,
text="",
content=[],
agent_id="",
model="",
name=None,

View File

@@ -4,12 +4,14 @@ from sqlalchemy.types import BINARY, TypeDecorator
from letta.helpers.converters import (
deserialize_embedding_config,
deserialize_llm_config,
deserialize_message_content,
deserialize_tool_calls,
deserialize_tool_returns,
deserialize_tool_rules,
deserialize_vector,
serialize_embedding_config,
serialize_llm_config,
serialize_message_content,
serialize_tool_calls,
serialize_tool_returns,
serialize_tool_rules,
@@ -82,6 +84,19 @@ class ToolReturnColumn(TypeDecorator):
return deserialize_tool_returns(value)
class MessageContentColumn(TypeDecorator):
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""
impl = JSON
cache_ok = True
def process_bind_param(self, value, dialect):
return serialize_message_content(value)
def process_result_value(self, value, dialect):
return deserialize_message_content(value)
class CommonVector(TypeDecorator):
"""Custom SQLAlchemy column type for storing vectors in SQLite."""

View File

@@ -4,9 +4,10 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
from sqlalchemy import ForeignKey, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import ToolCallColumn, ToolReturnColumn
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.letta_message_content import MessageContent
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import ToolReturn
@@ -25,6 +26,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)")
text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content")
content: Mapped[List[MessageContent]] = mapped_column(MessageContentColumn, nullable=True, doc="Message content parts")
model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used")
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
tool_calls: Mapped[List[OpenAIToolCall]] = mapped_column(ToolCallColumn, doc="Tool call information")
@@ -54,8 +56,8 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
return self.job_message.job if self.job_message else None
def to_pydantic(self) -> PydanticMessage:
"""custom pydantic conversion for message content mapping"""
"""Custom pydantic conversion to handle data using legacy text field"""
model = self.__pydantic_model__.model_validate(self)
if self.text:
if self.text and not model.content:
model.content = [PydanticTextContent(text=self.text)]
return model

View File

@@ -4,7 +4,12 @@ from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
from letta.schemas.letta_message_content import LettaMessageContentUnion, get_letta_message_content_union_str_json_schema
from letta.schemas.letta_message_content import (
LettaAssistantMessageContentUnion,
LettaUserMessageContentUnion,
get_letta_assistant_message_content_union_str_json_schema,
get_letta_user_message_content_union_str_json_schema,
)
# ---------------------------
# Letta API Messaging Schemas
@@ -20,11 +25,12 @@ class LettaMessage(BaseModel):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
name (Optional[str]): The name of the sender of the message
"""
id: str
date: datetime
name: Optional[str] = None
@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):
@@ -44,15 +50,12 @@ class SystemMessage(LettaMessage):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the system (can be a string or an array of content parts)
name (Optional[str]): The name of the sender of the message
content (str): The message content sent by the system
"""
message_type: Literal["system_message"] = "system_message"
content: Union[str, List[LettaMessageContentUnion]] = Field(
...,
description="The message content sent by the system (can be a string or an array of content parts)",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
)
content: str = Field(..., description="The message content sent by the system")
class UserMessage(LettaMessage):
@@ -62,14 +65,15 @@ class UserMessage(LettaMessage):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
name (Optional[str]): The name of the sender of the message
content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts)
"""
message_type: Literal["user_message"] = "user_message"
content: Union[str, List[LettaMessageContentUnion]] = Field(
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
...,
description="The message content sent by the user (can be a string or an array of content parts)",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
)
@@ -80,10 +84,33 @@ class ReasoningMessage(LettaMessage):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
name (Optional[str]): The name of the sender of the message
source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning
content was generated natively by a reasoner model or derived via prompting
reasoning (str): The internal reasoning of the agent
"""
message_type: Literal["reasoning_message"] = "reasoning_message"
source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model"
reasoning: str
class HiddenReasoningMessage(LettaMessage):
"""
Representation of an agent's internal reasoning where reasoning content
has been hidden from the response.
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
name (Optional[str]): The name of the sender of the message
state (Literal["redacted", "omitted"]): Whether the reasoning
content was redacted by the provider or simply omitted by the API
reasoning (str): The internal reasoning of the agent
"""
message_type: Literal["reasoning_message"] = "reasoning_message"
state: Literal["redacted", "omitted"]
reasoning: str
@@ -117,6 +144,7 @@ class ToolCallMessage(LettaMessage):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
name (Optional[str]): The name of the sender of the message
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
"""
@@ -164,6 +192,7 @@ class ToolReturnMessage(LettaMessage):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
name (Optional[str]): The name of the sender of the message
tool_return (str): The return value of the tool
status (Literal["success", "error"]): The status of the tool call
tool_call_id (str): A unique identifier for the tool call that generated this message
@@ -186,14 +215,15 @@ class AssistantMessage(LettaMessage):
Args:
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
name (Optional[str]): The name of the sender of the message
content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
"""
message_type: Literal["assistant_message"] = "assistant_message"
content: Union[str, List[LettaMessageContentUnion]] = Field(
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
...,
description="The message content sent by the agent (can be a string or an array of content parts)",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
)
@@ -235,19 +265,17 @@ def create_letta_message_union_schema():
class UpdateSystemMessage(BaseModel):
message_type: Literal["system_message"] = "system_message"
content: Union[str, List[LettaMessageContentUnion]] = Field(
...,
description="The message content sent by the system (can be a string or an array of content parts)",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
content: str = Field(
..., description="The message content sent by the system (can be a string or an array of multi-modal content parts)"
)
class UpdateUserMessage(BaseModel):
message_type: Literal["user_message"] = "user_message"
content: Union[str, List[LettaMessageContentUnion]] = Field(
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
...,
description="The message content sent by the user (can be a string or an array of content parts)",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
)
@@ -258,10 +286,10 @@ class UpdateReasoningMessage(BaseModel):
class UpdateAssistantMessage(BaseModel):
message_type: Literal["assistant_message"] = "assistant_message"
content: Union[str, List[LettaMessageContentUnion]] = Field(
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
...,
description="The message content sent by the assistant (can be a string or an array of content parts)",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
)

View File

@@ -1,11 +1,16 @@
from enum import Enum
from typing import Annotated, Literal, Union
from typing import Annotated, Literal, Optional, Union
from pydantic import BaseModel, Field
class MessageContentType(str, Enum):
text = "text"
tool_call = "tool_call"
tool_return = "tool_return"
reasoning = "reasoning"
redacted_reasoning = "redacted_reasoning"
omitted_reasoning = "omitted_reasoning"
class MessageContent(BaseModel):
@@ -13,7 +18,7 @@ class MessageContent(BaseModel):
# -------------------------------
# Multi-Modal Content Types
# User Content Types
# -------------------------------
@@ -22,13 +27,13 @@ class TextContent(MessageContent):
text: str = Field(..., description="The text content of the message.")
LettaMessageContentUnion = Annotated[
LettaUserMessageContentUnion = Annotated[
Union[TextContent],
Field(discriminator="type"),
]
def create_letta_message_content_union_schema():
def create_letta_user_message_content_union_schema():
return {
"oneOf": [
{"$ref": "#/components/schemas/TextContent"},
@@ -42,6 +47,137 @@ def create_letta_message_content_union_schema():
}
def get_letta_user_message_content_union_str_json_schema():
return {
"anyOf": [
{
"type": "array",
"items": {
"$ref": "#/components/schemas/LettaUserMessageContentUnion",
},
},
{"type": "string"},
],
}
# -------------------------------
# Assistant Content Types
# -------------------------------
LettaAssistantMessageContentUnion = Annotated[
Union[TextContent],
Field(discriminator="type"),
]
def create_letta_assistant_message_content_union_schema():
return {
"oneOf": [
{"$ref": "#/components/schemas/TextContent"},
],
"discriminator": {
"propertyName": "type",
"mapping": {
"text": "#/components/schemas/TextContent",
},
},
}
def get_letta_assistant_message_content_union_str_json_schema():
return {
"anyOf": [
{
"type": "array",
"items": {
"$ref": "#/components/schemas/LettaAssistantMessageContentUnion",
},
},
{"type": "string"},
],
}
# -------------------------------
# Intermediate Step Content Types
# -------------------------------
class ToolCallContent(MessageContent):
type: Literal[MessageContentType.tool_call] = Field(
MessageContentType.tool_call, description="Indicates this content represents a tool call event."
)
id: str = Field(..., description="A unique identifier for this specific tool call instance.")
name: str = Field(..., description="The name of the tool being called.")
input: dict = Field(
..., description="The parameters being passed to the tool, structured as a dictionary of parameter names to values."
)
class ToolReturnContent(MessageContent):
type: Literal[MessageContentType.tool_return] = Field(
MessageContentType.tool_return, description="Indicates this content represents a tool return event."
)
tool_call_id: str = Field(..., description="References the ID of the ToolCallContent that initiated this tool call.")
content: str = Field(..., description="The content returned by the tool execution.")
is_error: bool = Field(..., description="Indicates whether the tool execution resulted in an error.")
class ReasoningContent(MessageContent):
type: Literal[MessageContentType.reasoning] = Field(
MessageContentType.reasoning, description="Indicates this is a reasoning/intermediate step."
)
is_native: bool = Field(..., description="Whether the reasoning content was generated by a reasoner model that processed this step.")
reasoning: str = Field(..., description="The intermediate reasoning or thought process content.")
signature: Optional[str] = Field(None, description="A unique identifier for this reasoning step.")
class RedactedReasoningContent(MessageContent):
type: Literal[MessageContentType.redacted_reasoning] = Field(
MessageContentType.redacted_reasoning, description="Indicates this is a redacted thinking step."
)
data: str = Field(..., description="The redacted or filtered intermediate reasoning content.")
class OmittedReasoningContent(MessageContent):
type: Literal[MessageContentType.omitted_reasoning] = Field(
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
)
tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
LettaMessageContentUnion = Annotated[
Union[TextContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent],
Field(discriminator="type"),
]
def create_letta_message_content_union_schema():
return {
"oneOf": [
{"$ref": "#/components/schemas/TextContent"},
{"$ref": "#/components/schemas/ToolCallContent"},
{"$ref": "#/components/schemas/ToolReturnContent"},
{"$ref": "#/components/schemas/ReasoningContent"},
{"$ref": "#/components/schemas/RedactedReasoningContent"},
{"$ref": "#/components/schemas/OmittedReasoningContent"},
],
"discriminator": {
"propertyName": "type",
"mapping": {
"text": "#/components/schemas/TextContent",
"tool_call": "#/components/schemas/ToolCallContent",
"tool_return": "#/components/schemas/ToolCallContent",
"reasoning": "#/components/schemas/ReasoningContent",
"redacted_reasoning": "#/components/schemas/RedactedReasoningContent",
"omitted_reasoning": "#/components/schemas/OmittedReasoningContent",
},
},
}
def get_letta_message_content_union_str_json_schema():
return {
"anyOf": [

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
@@ -27,7 +27,7 @@ from letta.schemas.letta_message import (
ToolReturnMessage,
UserMessage,
)
from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent
from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent, get_letta_message_content_union_str_json_schema
from letta.system import unpack_message
@@ -65,15 +65,30 @@ class MessageCreate(BaseModel):
MessageRole.user,
MessageRole.system,
] = Field(..., description="The role of the participant.")
content: Union[str, List[LettaMessageContentUnion]] = Field(..., description="The content of the message.")
content: Union[str, List[LettaMessageContentUnion]] = Field(
...,
description="The content of the message.",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
)
name: Optional[str] = Field(None, description="The name of the participant.")
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
if to_orm and "content" in data:
if isinstance(data["content"], str):
data["content"] = [TextContent(text=data["content"])]
return data
class MessageUpdate(BaseModel):
"""Request to update a message"""
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(None, description="The content of the message.")
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
None,
description="The content of the message.",
json_schema_extra=get_letta_message_content_union_str_json_schema(),
)
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
@@ -89,12 +104,7 @@ class MessageUpdate(BaseModel):
data = super().model_dump(**kwargs)
if to_orm and "content" in data:
if isinstance(data["content"], str):
data["text"] = data["content"]
else:
for content in data["content"]:
if content["type"] == "text":
data["text"] = content["text"]
del data["content"]
data["content"] = [TextContent(text=data["content"])]
return data
@@ -140,24 +150,6 @@ class Message(BaseMessage):
assert v in roles, f"Role must be one of {roles}"
return v
@model_validator(mode="before")
@classmethod
def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(data, dict):
if "text" in data and "content" not in data:
data["content"] = [TextContent(text=data["text"])]
del data["text"]
return data
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
if to_orm:
for content in data["content"]:
if content["type"] == "text":
data["text"] = content["text"]
del data["content"]
return data
def to_json(self):
json_message = vars(self)
if json_message["tool_calls"] is not None:

View File

@@ -3,10 +3,12 @@ from marshmallow import fields
from letta.helpers.converters import (
deserialize_embedding_config,
deserialize_llm_config,
deserialize_message_content,
deserialize_tool_calls,
deserialize_tool_rules,
serialize_embedding_config,
serialize_llm_config,
serialize_message_content,
serialize_tool_calls,
serialize_tool_rules,
)
@@ -67,3 +69,13 @@ class ToolCallField(fields.Field):
def _deserialize(self, value, attr, data, **kwargs):
return deserialize_tool_calls(value)
class MessageContentField(fields.Field):
"""Marshmallow field for handling a list of Message Content Part objects."""
def _serialize(self, value, attr, obj, **kwargs):
return serialize_message_content(value)
def _deserialize(self, value, attr, data, **kwargs):
return deserialize_message_content(value)

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
@@ -27,7 +28,7 @@ class MessageSchema(BaseModel):
model: Optional[str]
name: Optional[str]
role: str
text: str
content: List[TextContent] # TODO: Expand to more in the future
tool_call_id: Optional[str]
tool_calls: List[Any]
tool_returns: List[Any]

View File

@@ -17,7 +17,11 @@ from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaU
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
from letta.schemas.letta_message import create_letta_message_union_schema
from letta.schemas.letta_message_content import create_letta_message_content_union_schema
from letta.schemas.letta_message_content import (
create_letta_assistant_message_content_union_schema,
create_letta_message_content_union_schema,
create_letta_user_message_content_union_schema,
)
from letta.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
@@ -70,6 +74,9 @@ def generate_openapi_schema(app: FastAPI):
letta_docs["info"]["title"] = "Letta API"
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()
for name, docs in [
(
"letta",

View File

@@ -740,7 +740,7 @@ class SyncServer(Server):
Message(
agent_id=agent_id,
role=message.role,
content=[TextContent(text=message.content)],
content=[TextContent(text=message.content)] if message.content else [],
name=message.name,
# assigned later?
model=None,

View File

@@ -1,7 +1,7 @@
import json
from typing import List, Optional
from sqlalchemy import and_, or_
from sqlalchemy import and_, exists, func, or_, select, text
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
@@ -233,9 +233,17 @@ class MessageManager:
# Build a query that directly filters the Message table by agent_id.
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
# If query_text is provided, filter messages by partial match on text.
# If query_text is provided, filter messages using subquery.
if query_text:
query = query.filter(MessageModel.text.ilike(f"%{query_text}%"))
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
query = query.filter(
exists(
select(1)
.select_from(content_element)
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
.params(query_text=f"%{query_text}%")
)
)
# If role is provided, filter messages by role.
if role:

View File

@@ -482,7 +482,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
agent_id = serialize_test_agent.id
# Step 1: Download the serialized agent
response = fastapi_client.get(f"/v1/agents/{agent_id}/download", headers={"user_id": default_user.id})
response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id})
assert response.status_code == 200, f"Download failed: {response.text}"
# Ensure response matches expected schema
@@ -493,7 +493,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
files = {"file": ("agent.json", agent_bytes, "application/json")}
upload_response = fastapi_client.post(
"/v1/agents/upload",
"/v1/agents/import",
headers={"user_id": other_user.id},
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
files=files,

View File

@@ -25,6 +25,7 @@ from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPro
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
@@ -272,7 +273,7 @@ def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent):
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role="user",
text="Hello, world!",
content=[TextContent(text="Hello, world!")],
)
msg = server.message_manager.create_message(message, actor=default_user)
@@ -1196,7 +1197,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent,
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
role="user",
text="Hello, Sarah!",
content=[TextContent(text="Hello, Sarah!")],
),
actor=default_user,
)
@@ -1205,7 +1206,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent,
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
role="assistant",
text="Hello, user!",
content=[TextContent(text="Hello, user!")],
),
actor=default_user,
)
@@ -1236,7 +1237,7 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
role="user",
text="Hello, Sarah!",
content=[TextContent(text="Hello, Sarah!")],
),
actor=default_user,
)
@@ -2062,7 +2063,10 @@ def test_message_size(server: SyncServer, hello_world_message_fixture, default_u
# Create additional test messages
messages = [
PydanticMessage(
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
organization_id=default_user.organization_id,
agent_id=base_message.agent_id,
role=base_message.role,
content=[TextContent(text=f"Test message {i}")],
)
for i in range(4)
]
@@ -2090,7 +2094,10 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa
"""Helper function to create test messages for all tests"""
messages = [
PydanticMessage(
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
organization_id=default_user.organization_id,
agent_id=base_message.agent_id,
role=base_message.role,
content=[TextContent(text=f"Test message {i}")],
)
for i in range(4)
]
@@ -3270,7 +3277,7 @@ def test_job_messages_pagination(server: SyncServer, default_run, default_user,
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role=MessageRole.user,
text=f"Test message {i}",
content=[TextContent(text=f"Test message {i}")],
)
msg = server.message_manager.create_message(message, actor=default_user)
message_ids.append(msg.id)
@@ -3383,7 +3390,7 @@ def test_job_messages_ordering(server: SyncServer, default_run, default_user, sa
for i, created_at in enumerate(message_times):
message = PydanticMessage(
role=MessageRole.user,
text="Test message",
content=[TextContent(text="Test message")],
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
created_at=created_at,
@@ -3452,19 +3459,19 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara
messages = [
PydanticMessage(
role=MessageRole.user,
text="Hello",
content=[TextContent(text="Hello")],
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
),
PydanticMessage(
role=MessageRole.assistant,
text="Hi there!",
content=[TextContent(text="Hi there!")],
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
),
PydanticMessage(
role=MessageRole.assistant,
text="Let me help you with that",
content=[TextContent(text="Let me help you with that")],
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
tool_calls=[
@@ -3519,7 +3526,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],
tool_calls=(
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
if i % 2 == 1
@@ -3570,7 +3577,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],
tool_calls=(
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
if i % 2 == 1