diff --git a/alembic/versions/95badb46fdf9_migrate_message_to_orm.py b/alembic/versions/d27a33843feb_migrate_messages_to_the_orm.py similarity index 87% rename from alembic/versions/95badb46fdf9_migrate_message_to_orm.py rename to alembic/versions/d27a33843feb_migrate_messages_to_the_orm.py index 73254e39..575bd25b 100644 --- a/alembic/versions/95badb46fdf9_migrate_message_to_orm.py +++ b/alembic/versions/d27a33843feb_migrate_messages_to_the_orm.py @@ -1,8 +1,8 @@ -"""Migrate message to orm +"""Migrate messages to the orm -Revision ID: 95badb46fdf9 -Revises: 3c683a662c82 -Create Date: 2024-12-05 14:02:04.163150 +Revision ID: d27a33843feb +Revises: 08b2f8225812 +Create Date: 2024-12-07 13:52:20.591898 """ @@ -14,7 +14,7 @@ from sqlalchemy.dialects import postgresql from alembic import op # revision identifiers, used by Alembic. -revision: str = "95badb46fdf9" +revision: str = "d27a33843feb" down_revision: Union[str, None] = "08b2f8225812" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -39,10 +39,9 @@ def upgrade() -> None: ) op.alter_column("messages", "organization_id", nullable=False) op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) - op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) op.drop_index("message_idx_user", table_name="messages") - op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"]) op.create_foreign_key(None, "messages", "organizations", ["organization_id"], ["id"]) + op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"]) op.drop_column("messages", "user_id") # ### end Alembic commands ### @@ -53,7 +52,6 @@ def downgrade() -> None: op.drop_constraint(None, "messages", type_="foreignkey") op.drop_constraint(None, "messages", type_="foreignkey") op.create_index("message_idx_user", "messages", ["user_id", "agent_id"], unique=False) - op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True) op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) op.drop_column("messages", "organization_id") op.drop_column("messages", "_last_updated_by_id") diff --git a/letta/agent.py b/letta/agent.py index 3e619ea5..12368a93 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -36,7 +36,7 @@ from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.memory import ContextWindowOverview, Memory -from letta.schemas.message import Message, UpdateMessage +from letta.schemas.message import Message, MessageUpdate from letta.schemas.openai.chat_completion_request import ( Tool as ChatCompletionRequestTool, ) @@ -512,9 +512,10 @@ class Agent(BaseAgent): for m in self._messages: # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}" # TODO eventually do casting via an edit_message function - if not is_utc_datetime(m.created_at): - printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") - m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) + if m.created_at: + if not is_utc_datetime(m.created_at): + printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") + m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) def set_message_buffer(self, message_ids: List[str], force_utc: bool = True): """Set the messages in the buffer to the message IDs list""" @@ -1405,36 +1406,10 @@ class Agent(BaseAgent): f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.", ) - def update_message(self, request: UpdateMessage) -> Message: + def update_message(self, message_id: str, request: MessageUpdate) -> Message: """Update the details of a message associated with an agent""" - - message = self.message_manager.get_message_by_id(message_id=request.id, actor=self.user) - if message is None: - raise ValueError(f"Message with id {request.id} not found") - assert isinstance(message, Message), f"Message is not a Message object: {type(message)}" - - # Override fields - # NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof - if request.role: - message.role = request.role - if request.text: - message.text = request.text - if request.name: - message.name = request.name - if request.tool_calls: - assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages" - message.tool_calls = request.tool_calls - if request.tool_call_id: - assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages" - message.tool_call_id = request.tool_call_id - # Save the updated message - self.message_manager.update_message_by_id(message_id=message.id, message=message, actor=self.user) - - # Return the updated message - updated_message = self.message_manager.get_message_by_id(message_id=message.id, actor=self.user) - if updated_message is None: - raise ValueError(f"Error persisting message - message with id {request.id} not found") + updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user) return updated_message # TODO(sarah): should we be creating a new message here, or just editing a message? @@ -1444,10 +1419,10 @@ class Agent(BaseAgent): msg_obj = self._messages[x] if msg_obj.role == MessageRole.assistant: updated_message = self.update_message( - request=UpdateMessage( - id=msg_obj.id, + message_id=msg_obj.id, + request=MessageUpdate( text=new_thought, - ) + ), ) self.refresh_message_buffer() return updated_message @@ -1486,10 +1461,10 @@ class Agent(BaseAgent): # Write the update to the DB updated_message = self.update_message( - request=UpdateMessage( - id=message_obj.id, + message_id=message_obj.id, + request=MessageUpdate( tool_calls=message_obj.tool_calls, - ) + ), ) self.refresh_message_buffer() return updated_message diff --git a/letta/client/client.py b/letta/client/client.py index eaf477e9..afccdafa 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -32,7 +32,7 @@ from letta.schemas.memory import ( Memory, RecallMemorySummary, ) -from letta.schemas.message import Message, MessageCreate, UpdateMessage +from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.openai.chat_completions import ToolCall from letta.schemas.organization import Organization from letta.schemas.passage import Passage @@ -586,8 +586,7 @@ class RESTClient(AbstractClient): tool_calls: Optional[List[ToolCall]] = None, tool_call_id: Optional[str] = None, ) -> Message: - request = UpdateMessage( - id=message_id, + request = MessageUpdate( role=role, text=text, name=name, @@ -2148,8 +2147,8 @@ class LocalClient(AbstractClient): ) -> Message: message = self.server.update_agent_message( agent_id=agent_id, - request=UpdateMessage( - id=message_id, + message_id=message_id, + request=MessageUpdate( role=role, text=text, name=name, diff --git a/letta/orm/message.py b/letta/orm/message.py index 3f0b56c7..8de6f1f5 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,7 +1,6 @@ -from datetime import datetime from typing import Optional -from sqlalchemy import JSON, DateTime, TypeDecorator +from sqlalchemy import JSON, TypeDecorator from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import AgentMixin, OrganizationMixin @@ -58,7 +57,6 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios") tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information") tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call") - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow) # Relationships # TODO: Add in after Agent ORM is created diff --git a/letta/schemas/message.py b/letta/schemas/message.py index a9e2fcb8..750e4a05 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -13,7 +13,7 @@ from letta.constants import ( ) from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.enums import MessageRole -from letta.schemas.letta_base import LettaBase +from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( AssistantMessage, FunctionCall, @@ -50,7 +50,7 @@ def add_inner_thoughts_to_tool_call( raise e -class BaseMessage(LettaBase): +class BaseMessage(OrmMetadataBase): __id_prefix__ = "message" @@ -66,10 +66,9 @@ class MessageCreate(BaseMessage): name: Optional[str] = Field(None, description="The name of the participant.") -class UpdateMessage(BaseMessage): +class MessageUpdate(BaseMessage): """Request to update a message""" - id: str = Field(..., description="The id of the message.") role: Optional[MessageRole] = Field(None, description="The role of the participant.") text: Optional[str] = Field(None, description="The text of the message.") # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message) @@ -109,9 +108,10 @@ class Message(BaseMessage): agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") model: Optional[str] = Field(None, description="The model used to make the function call.") name: Optional[str] = Field(None, description="The name of the participant.") - created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.") tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.") tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") + # This overrides the optional base orm schema, created_at MUST exist on all messages objects + created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") @field_validator("role") @classmethod diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index ebf10b7c..e3922a28 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -28,7 +28,7 @@ from letta.schemas.memory import ( Memory, RecallMemorySummary, ) -from letta.schemas.message import Message, MessageCreate, UpdateMessage +from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.passage import Passage from letta.schemas.source import Source from letta.schemas.tool import Tool @@ -422,14 +422,13 @@ def get_agent_messages( def update_message( agent_id: str, message_id: str, - request: UpdateMessage = Body(...), + request: MessageUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): """ Update the details of a message associated with an agent. """ - assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}" - return server.update_agent_message(agent_id=agent_id, request=request) + return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request) @router.post( diff --git a/letta/server/server.py b/letta/server/server.py index fa3d29fa..da87d9df 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -67,7 +67,7 @@ from letta.schemas.memory import ( Memory, RecallMemorySummary, ) -from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage +from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.source import Source @@ -1662,12 +1662,12 @@ class SyncServer(Server): save_agent(letta_agent, self.ms) return message - def update_agent_message(self, agent_id: str, request: UpdateMessage) -> Message: + def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate) -> Message: """Update the details of a message associated with an agent""" # Get the current message letta_agent = self.load_agent(agent_id=agent_id) - response = letta_agent.update_message(request=request) + response = letta_agent.update_message(message_id=message_id, request=request) save_agent(letta_agent, self.ms) return response diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 7a46ddba..b9932b39 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -5,6 +5,7 @@ from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.schemas.enums import MessageRole from letta.schemas.message import Message as PydanticMessage +from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -44,27 +45,38 @@ class MessageManager: return [self.create_message(m, actor=actor) for m in pydantic_msgs] @enforce_types - def update_message_by_id(self, message_id: str, message: PydanticMessage, actor: PydanticUser) -> PydanticMessage: + def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage: """ Updates an existing record in the database with values from the provided record object. """ with self.session_maker() as session: # Fetch existing message from database - msg = MessageModel.read( + message = MessageModel.read( db_session=session, identifier=message_id, actor=actor, ) - # Update the database record with values from the provided record - for column in MessageModel.__table__.columns: - column_name = column.name - if hasattr(message, column_name): - new_value = getattr(message, column_name) - setattr(msg, column_name, new_value) + # Some safety checks specific to messages + if message_update.tool_calls and message.role != MessageRole.assistant: + raise ValueError( + f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}." + ) + if message_update.tool_call_id and message.role != MessageRole.tool: + raise ValueError( + f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}." + ) - # Commit changes - return msg.update(db_session=session, actor=actor).to_pydantic() + # get update dictionary + update_data = message_update.model_dump(exclude_unset=True, exclude_none=True) + # Remove redundant update fields + update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value} + + for key, value in update_data.items(): + setattr(message, key, value) + message.update(db_session=session, actor=actor) + + return message.to_pydantic() @enforce_types def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool: diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index ce93457b..65bd5f16 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -279,11 +279,7 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent assert chunk.total_tokens > 1000 # If stream tokens, we expect at least one inner thought - if stream_tokens: - assert inner_thoughts_count > 1, "Expected more than one inner thought" - else: - assert inner_thoughts_count == 1, "Expected one inner thought" - + assert inner_thoughts_count >= 1, "Expected more than one inner thought" assert inner_thoughts_exist, "No inner thoughts found" assert send_message_ran, "send_message function call not found" assert done, "Message stream not done" diff --git a/tests/test_managers.py b/tests/test_managers.py index 53107dfe..dc476938 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -37,6 +37,7 @@ from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage +from letta.schemas.message import MessageUpdate from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.sandbox_config import ( E2BSandboxConfig, @@ -598,16 +599,19 @@ def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, defa assert retrieved.text == hello_world_message_fixture.text -def test_message_update(server: SyncServer, hello_world_message_fixture, default_user): +def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user): """Test updating a message""" new_text = "Updated text" - hello_world_message_fixture.text = new_text - updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, hello_world_message_fixture, actor=default_user) + updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(text=new_text), actor=other_user) assert updated is not None assert updated.text == new_text retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user) assert retrieved.text == new_text + # Assert that orm metadata fields are populated + assert retrieved.created_by_id == default_user.id + assert retrieved.last_updated_by_id == other_user.id + def test_message_delete(server: SyncServer, hello_world_message_fixture, default_user): """Test deleting a message"""