fix: Updating messages (#2186)

This commit is contained in:
Matthew Zhou
2024-12-07 14:09:20 -08:00
committed by GitHub
parent 79cc78f5cb
commit 1f57569116
10 changed files with 65 additions and 84 deletions

View File

@@ -1,8 +1,8 @@
"""Migrate message to orm
"""Migrate messages to the orm
Revision ID: 95badb46fdf9
Revises: 3c683a662c82
Create Date: 2024-12-05 14:02:04.163150
Revision ID: d27a33843feb
Revises: 08b2f8225812
Create Date: 2024-12-07 13:52:20.591898
"""
@@ -14,7 +14,7 @@ from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "95badb46fdf9"
revision: str = "d27a33843feb"
down_revision: Union[str, None] = "08b2f8225812"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -39,10 +39,9 @@ def upgrade() -> None:
)
op.alter_column("messages", "organization_id", nullable=False)
op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
op.drop_index("message_idx_user", table_name="messages")
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"])
op.create_foreign_key(None, "messages", "organizations", ["organization_id"], ["id"])
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"])
op.drop_column("messages", "user_id")
# ### end Alembic commands ###
@@ -53,7 +52,6 @@ def downgrade() -> None:
op.drop_constraint(None, "messages", type_="foreignkey")
op.drop_constraint(None, "messages", type_="foreignkey")
op.create_index("message_idx_user", "messages", ["user_id", "agent_id"], unique=False)
op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
op.drop_column("messages", "organization_id")
op.drop_column("messages", "_last_updated_by_id")

View File

@@ -36,7 +36,7 @@ from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, UpdateMessage
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.openai.chat_completion_request import (
Tool as ChatCompletionRequestTool,
)
@@ -512,9 +512,10 @@ class Agent(BaseAgent):
for m in self._messages:
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
if m.created_at:
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
"""Set the messages in the buffer to the message IDs list"""
@@ -1405,36 +1406,10 @@ class Agent(BaseAgent):
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
)
def update_message(self, request: UpdateMessage) -> Message:
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
"""Update the details of a message associated with an agent"""
message = self.message_manager.get_message_by_id(message_id=request.id, actor=self.user)
if message is None:
raise ValueError(f"Message with id {request.id} not found")
assert isinstance(message, Message), f"Message is not a Message object: {type(message)}"
# Override fields
# NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof
if request.role:
message.role = request.role
if request.text:
message.text = request.text
if request.name:
message.name = request.name
if request.tool_calls:
assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages"
message.tool_calls = request.tool_calls
if request.tool_call_id:
assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages"
message.tool_call_id = request.tool_call_id
# Save the updated message
self.message_manager.update_message_by_id(message_id=message.id, message=message, actor=self.user)
# Return the updated message
updated_message = self.message_manager.get_message_by_id(message_id=message.id, actor=self.user)
if updated_message is None:
raise ValueError(f"Error persisting message - message with id {request.id} not found")
updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user)
return updated_message
# TODO(sarah): should we be creating a new message here, or just editing a message?
@@ -1444,10 +1419,10 @@ class Agent(BaseAgent):
msg_obj = self._messages[x]
if msg_obj.role == MessageRole.assistant:
updated_message = self.update_message(
request=UpdateMessage(
id=msg_obj.id,
message_id=msg_obj.id,
request=MessageUpdate(
text=new_thought,
)
),
)
self.refresh_message_buffer()
return updated_message
@@ -1486,10 +1461,10 @@ class Agent(BaseAgent):
# Write the update to the DB
updated_message = self.update_message(
request=UpdateMessage(
id=message_obj.id,
message_id=message_obj.id,
request=MessageUpdate(
tool_calls=message_obj.tool_calls,
)
),
)
self.refresh_message_buffer()
return updated_message

View File

@@ -32,7 +32,7 @@ from letta.schemas.memory import (
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, UpdateMessage
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
@@ -586,8 +586,7 @@ class RESTClient(AbstractClient):
tool_calls: Optional[List[ToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
request = UpdateMessage(
id=message_id,
request = MessageUpdate(
role=role,
text=text,
name=name,
@@ -2148,8 +2147,8 @@ class LocalClient(AbstractClient):
) -> Message:
message = self.server.update_agent_message(
agent_id=agent_id,
request=UpdateMessage(
id=message_id,
message_id=message_id,
request=MessageUpdate(
role=role,
text=text,
name=name,

View File

@@ -1,7 +1,6 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import JSON, DateTime, TypeDecorator
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import AgentMixin, OrganizationMixin
@@ -58,7 +57,6 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
# Relationships
# TODO: Add in after Agent ORM is created

View File

@@ -13,7 +13,7 @@ from letta.constants import (
)
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCall,
@@ -50,7 +50,7 @@ def add_inner_thoughts_to_tool_call(
raise e
class BaseMessage(LettaBase):
class BaseMessage(OrmMetadataBase):
__id_prefix__ = "message"
@@ -66,10 +66,9 @@ class MessageCreate(BaseMessage):
name: Optional[str] = Field(None, description="The name of the participant.")
class UpdateMessage(BaseMessage):
class MessageUpdate(BaseMessage):
"""Request to update a message"""
id: str = Field(..., description="The id of the message.")
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
text: Optional[str] = Field(None, description="The text of the message.")
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
@@ -109,9 +108,10 @@ class Message(BaseMessage):
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")
name: Optional[str] = Field(None, description="The name of the participant.")
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@field_validator("role")
@classmethod

View File

@@ -28,7 +28,7 @@ from letta.schemas.memory import (
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, UpdateMessage
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.schemas.tool import Tool
@@ -422,14 +422,13 @@ def get_agent_messages(
def update_message(
agent_id: str,
message_id: str,
request: UpdateMessage = Body(...),
request: MessageUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Update the details of a message associated with an agent.
"""
assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}"
return server.update_agent_message(agent_id=agent_id, request=request)
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request)
@router.post(

View File

@@ -67,7 +67,7 @@ from letta.schemas.memory import (
Memory,
RecallMemorySummary,
)
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source
@@ -1662,12 +1662,12 @@ class SyncServer(Server):
save_agent(letta_agent, self.ms)
return message
def update_agent_message(self, agent_id: str, request: UpdateMessage) -> Message:
def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate) -> Message:
"""Update the details of a message associated with an agent"""
# Get the current message
letta_agent = self.load_agent(agent_id=agent_id)
response = letta_agent.update_message(request=request)
response = letta_agent.update_message(message_id=message_id, request=request)
save_agent(letta_agent, self.ms)
return response

View File

@@ -5,6 +5,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
@@ -44,27 +45,38 @@ class MessageManager:
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
@enforce_types
def update_message_by_id(self, message_id: str, message: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
"""
with self.session_maker() as session:
# Fetch existing message from database
msg = MessageModel.read(
message = MessageModel.read(
db_session=session,
identifier=message_id,
actor=actor,
)
# Update the database record with values from the provided record
for column in MessageModel.__table__.columns:
column_name = column.name
if hasattr(message, column_name):
new_value = getattr(message, column_name)
setattr(msg, column_name, new_value)
# Some safety checks specific to messages
if message_update.tool_calls and message.role != MessageRole.assistant:
raise ValueError(
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
)
if message_update.tool_call_id and message.role != MessageRole.tool:
raise ValueError(
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
)
# Commit changes
return msg.update(db_session=session, actor=actor).to_pydantic()
# get update dictionary
update_data = message_update.model_dump(exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
for key, value in update_data.items():
setattr(message, key, value)
message.update(db_session=session, actor=actor)
return message.to_pydantic()
@enforce_types
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:

View File

@@ -279,11 +279,7 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
assert chunk.total_tokens > 1000
# If stream tokens, we expect at least one inner thought
if stream_tokens:
assert inner_thoughts_count > 1, "Expected more than one inner thought"
else:
assert inner_thoughts_count == 1, "Expected one inner thought"
assert inner_thoughts_count >= 1, "Expected more than one inner thought"
assert inner_thoughts_exist, "No inner thoughts found"
assert send_message_ran, "send_message function call not found"
assert done, "Message stream not done"

View File

@@ -37,6 +37,7 @@ from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.sandbox_config import (
E2BSandboxConfig,
@@ -598,16 +599,19 @@ def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, defa
assert retrieved.text == hello_world_message_fixture.text
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user):
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user):
"""Test updating a message"""
new_text = "Updated text"
hello_world_message_fixture.text = new_text
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, hello_world_message_fixture, actor=default_user)
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(text=new_text), actor=other_user)
assert updated is not None
assert updated.text == new_text
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
assert retrieved.text == new_text
# Assert that orm metadata fields are populated
assert retrieved.created_by_id == default_user.id
assert retrieved.last_updated_by_id == other_user.id
def test_message_delete(server: SyncServer, hello_world_message_fixture, default_user):
"""Test deleting a message"""