From eaf9af3d03d40be0306ef8b37f5a4f6d35cffdd1 Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 14 Apr 2025 16:50:37 -0700 Subject: [PATCH] feat: add identity id to message model (#1700) --- .../c3b1da3d1157_add_sender_id_to_message.py | 31 +++++++++++++++++++ letta/helpers/message_helper.py | 1 + letta/orm/message.py | 3 ++ letta/schemas/letta_message.py | 31 ++++++++++++++----- letta/schemas/message.py | 10 ++++++ letta/server/rest_api/utils.py | 1 + 6 files changed, 70 insertions(+), 7 deletions(-) create mode 100644 alembic/versions/c3b1da3d1157_add_sender_id_to_message.py diff --git a/alembic/versions/c3b1da3d1157_add_sender_id_to_message.py b/alembic/versions/c3b1da3d1157_add_sender_id_to_message.py new file mode 100644 index 00000000..bd59a118 --- /dev/null +++ b/alembic/versions/c3b1da3d1157_add_sender_id_to_message.py @@ -0,0 +1,31 @@ +"""add sender id to message + +Revision ID: c3b1da3d1157 +Revises: 0ceb975e0063 +Create Date: 2025-04-14 08:53:14.548061 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c3b1da3d1157" +down_revision: Union[str, None] = "0ceb975e0063" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("messages", sa.Column("sender_id", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("messages", "sender_id") + # ### end Alembic commands ### diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 5f040ced..5f5b6c04 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -39,4 +39,5 @@ def prepare_input_message_create( tool_calls=None, # irrelevant tool_call_id=None, otid=message.otid, + sender_id=message.sender_id, ) diff --git a/letta/orm/message.py b/letta/orm/message.py index 589bd2d1..9f678bb1 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -41,6 +41,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): ToolReturnColumn, nullable=True, doc="Tool execution return information for prior tool calls" ) group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in") + sender_id: Mapped[Optional[str]] = mapped_column( + nullable=True, doc="The id of the sender of the message, can be an identity id or agent id" + ) # Monotonically increasing sequence for efficient/correct listing sequence_id = mapped_column(BigInteger, Sequence("message_seq_id"), unique=True, nullable=False) diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index ec58d8c6..1b4e8994 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -1,5 +1,6 @@ import json from datetime import datetime, timezone +from enum import Enum from typing import Annotated, List, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator @@ -16,6 +17,16 @@ from letta.schemas.letta_message_content import ( # --------------------------- +class MessageType(str, Enum): + system_message = "system_message" + user_message = "user_message" + assistant_message = "assistant_message" + reasoning_message = "reasoning_message" + hidden_reasoning_message = "hidden_reasoning_message" + tool_call_message = "tool_call_message" + tool_return_message = "tool_return_message" + + class LettaMessage(BaseModel): """ Base class for simplified Letta message response type. This is intended to be used for developers @@ -26,13 +37,17 @@ class LettaMessage(BaseModel): id (str): The ID of the message date (datetime): The date the message was created in ISO format name (Optional[str]): The name of the sender of the message + message_type (MessageType): The type of the message otid (Optional[str]): The offline threading id associated with this message + sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id """ id: str date: datetime name: Optional[str] = None + message_type: MessageType = Field(..., description="The type of the message.") otid: Optional[str] = None + sender_id: Optional[str] = None @field_serializer("date") def serialize_datetime(self, dt: datetime, _info): @@ -56,7 +71,7 @@ class SystemMessage(LettaMessage): content (str): The message content sent by the system """ - message_type: Literal["system_message"] = "system_message" + message_type: Literal[MessageType.system_message] = Field(MessageType.system_message, description="The type of the message.") content: str = Field(..., description="The message content sent by the system") @@ -71,7 +86,7 @@ class UserMessage(LettaMessage): content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts) """ - message_type: Literal["user_message"] = "user_message" + message_type: Literal[MessageType.user_message] = Field(MessageType.user_message, description="The type of the message.") content: Union[str, List[LettaUserMessageContentUnion]] = Field( ..., description="The message content sent by the user (can be a string or an array of multi-modal content parts)", @@ -93,7 +108,7 @@ class ReasoningMessage(LettaMessage): signature (Optional[str]): The model-generated signature of the reasoning step """ - message_type: Literal["reasoning_message"] = "reasoning_message" + message_type: Literal[MessageType.reasoning_message] = Field(MessageType.reasoning_message, description="The type of the message.") source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model" reasoning: str signature: Optional[str] = None @@ -113,7 +128,9 @@ class HiddenReasoningMessage(LettaMessage): hidden_reasoning (Optional[str]): The internal reasoning of the agent """ - message_type: Literal["hidden_reasoning_message"] = "hidden_reasoning_message" + message_type: Literal[MessageType.hidden_reasoning_message] = Field( + MessageType.hidden_reasoning_message, description="The type of the message." + ) state: Literal["redacted", "omitted"] hidden_reasoning: Optional[str] = None @@ -152,7 +169,7 @@ class ToolCallMessage(LettaMessage): tool_call (Union[ToolCall, ToolCallDelta]): The tool call """ - message_type: Literal["tool_call_message"] = "tool_call_message" + message_type: Literal[MessageType.tool_call_message] = Field(MessageType.tool_call_message, description="The type of the message.") tool_call: Union[ToolCall, ToolCallDelta] def model_dump(self, *args, **kwargs): @@ -204,7 +221,7 @@ class ToolReturnMessage(LettaMessage): stderr (Optional[List(str)]): Captured stderr from the tool invocation """ - message_type: Literal["tool_return_message"] = "tool_return_message" + message_type: Literal[MessageType.tool_return_message] = Field(MessageType.tool_return_message, description="The type of the message.") tool_return: str status: Literal["success", "error"] tool_call_id: str @@ -223,7 +240,7 @@ class AssistantMessage(LettaMessage): content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts) """ - message_type: Literal["assistant_message"] = "assistant_message" + message_type: Literal[MessageType.assistant_message] = Field(MessageType.assistant_message, description="The type of the message.") content: Union[str, List[LettaAssistantMessageContentUnion]] = Field( ..., description="The message content sent by the agent (can be a string or an array of content parts)", diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 0cb79238..4eb45f89 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -81,6 +81,7 @@ class MessageCreate(BaseModel): ) name: Optional[str] = Field(None, description="The name of the participant.") otid: Optional[str] = Field(None, description="The offline threading id associated with this message") + sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id") def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: data = super().model_dump(**kwargs) @@ -157,6 +158,7 @@ class Message(BaseMessage): otid: Optional[str] = Field(None, description="The offline threading id associated with this message") tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls") group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in") + sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id") # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") @@ -246,6 +248,7 @@ class Message(BaseMessage): reasoning=self.content[0].text, name=self.name, otid=otid, + sender_id=self.sender_id, ) ) # Otherwise, we may have a list of multiple types @@ -262,6 +265,7 @@ class Message(BaseMessage): reasoning=content_part.text, name=self.name, otid=otid, + sender_id=self.sender_id, ) ) elif isinstance(content_part, ReasoningContent): @@ -287,6 +291,7 @@ class Message(BaseMessage): hidden_reasoning=content_part.data, name=self.name, otid=otid, + sender_id=self.sender_id, ) ) else: @@ -312,6 +317,7 @@ class Message(BaseMessage): content=message_string, name=self.name, otid=otid, + sender_id=self.sender_id, ) ) else: @@ -326,6 +332,7 @@ class Message(BaseMessage): ), name=self.name, otid=otid, + sender_id=self.sender_id, ) ) elif self.role == MessageRole.tool: @@ -368,6 +375,7 @@ class Message(BaseMessage): stderr=self.tool_returns[0].stderr if self.tool_returns else None, name=self.name, otid=self.id.replace("message-", ""), + sender_id=self.sender_id, ) ) elif self.role == MessageRole.user: @@ -385,6 +393,7 @@ class Message(BaseMessage): content=message_str or text_content, name=self.name, otid=self.otid, + sender_id=self.sender_id, ) ) elif self.role == MessageRole.system: @@ -401,6 +410,7 @@ class Message(BaseMessage): content=text_content, name=self.name, otid=self.otid, + sender_id=self.sender_id, ) ) else: diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index af944ec1..2daa5d3e 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -153,6 +153,7 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ac content=input_message.content, name=input_message.name, otid=input_message.otid, + sender_id=input_message.sender_id, organization_id=actor.organization_id, agent_id=agent_id, model=None,