fix: Updating messages (#2186)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""Migrate message to orm
|
||||
"""Migrate messages to the orm
|
||||
|
||||
Revision ID: 95badb46fdf9
|
||||
Revises: 3c683a662c82
|
||||
Create Date: 2024-12-05 14:02:04.163150
|
||||
Revision ID: d27a33843feb
|
||||
Revises: 08b2f8225812
|
||||
Create Date: 2024-12-07 13:52:20.591898
|
||||
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@ from sqlalchemy.dialects import postgresql
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "95badb46fdf9"
|
||||
revision: str = "d27a33843feb"
|
||||
down_revision: Union[str, None] = "08b2f8225812"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@@ -39,10 +39,9 @@ def upgrade() -> None:
|
||||
)
|
||||
op.alter_column("messages", "organization_id", nullable=False)
|
||||
op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
||||
op.drop_index("message_idx_user", table_name="messages")
|
||||
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"])
|
||||
op.create_foreign_key(None, "messages", "organizations", ["organization_id"], ["id"])
|
||||
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"])
|
||||
op.drop_column("messages", "user_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -53,7 +52,6 @@ def downgrade() -> None:
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.create_index("message_idx_user", "messages", ["user_id", "agent_id"], unique=False)
|
||||
op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
|
||||
op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.drop_column("messages", "organization_id")
|
||||
op.drop_column("messages", "_last_updated_by_id")
|
||||
@@ -36,7 +36,7 @@ from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, UpdateMessage
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
Tool as ChatCompletionRequestTool,
|
||||
)
|
||||
@@ -512,9 +512,10 @@ class Agent(BaseAgent):
|
||||
for m in self._messages:
|
||||
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
|
||||
# TODO eventually do casting via an edit_message function
|
||||
if not is_utc_datetime(m.created_at):
|
||||
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
|
||||
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
|
||||
if m.created_at:
|
||||
if not is_utc_datetime(m.created_at):
|
||||
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
|
||||
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
|
||||
"""Set the messages in the buffer to the message IDs list"""
|
||||
@@ -1405,36 +1406,10 @@ class Agent(BaseAgent):
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
|
||||
)
|
||||
|
||||
def update_message(self, request: UpdateMessage) -> Message:
|
||||
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
|
||||
message = self.message_manager.get_message_by_id(message_id=request.id, actor=self.user)
|
||||
if message is None:
|
||||
raise ValueError(f"Message with id {request.id} not found")
|
||||
assert isinstance(message, Message), f"Message is not a Message object: {type(message)}"
|
||||
|
||||
# Override fields
|
||||
# NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof
|
||||
if request.role:
|
||||
message.role = request.role
|
||||
if request.text:
|
||||
message.text = request.text
|
||||
if request.name:
|
||||
message.name = request.name
|
||||
if request.tool_calls:
|
||||
assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages"
|
||||
message.tool_calls = request.tool_calls
|
||||
if request.tool_call_id:
|
||||
assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages"
|
||||
message.tool_call_id = request.tool_call_id
|
||||
|
||||
# Save the updated message
|
||||
self.message_manager.update_message_by_id(message_id=message.id, message=message, actor=self.user)
|
||||
|
||||
# Return the updated message
|
||||
updated_message = self.message_manager.get_message_by_id(message_id=message.id, actor=self.user)
|
||||
if updated_message is None:
|
||||
raise ValueError(f"Error persisting message - message with id {request.id} not found")
|
||||
updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user)
|
||||
return updated_message
|
||||
|
||||
# TODO(sarah): should we be creating a new message here, or just editing a message?
|
||||
@@ -1444,10 +1419,10 @@ class Agent(BaseAgent):
|
||||
msg_obj = self._messages[x]
|
||||
if msg_obj.role == MessageRole.assistant:
|
||||
updated_message = self.update_message(
|
||||
request=UpdateMessage(
|
||||
id=msg_obj.id,
|
||||
message_id=msg_obj.id,
|
||||
request=MessageUpdate(
|
||||
text=new_thought,
|
||||
)
|
||||
),
|
||||
)
|
||||
self.refresh_message_buffer()
|
||||
return updated_message
|
||||
@@ -1486,10 +1461,10 @@ class Agent(BaseAgent):
|
||||
|
||||
# Write the update to the DB
|
||||
updated_message = self.update_message(
|
||||
request=UpdateMessage(
|
||||
id=message_obj.id,
|
||||
message_id=message_obj.id,
|
||||
request=MessageUpdate(
|
||||
tool_calls=message_obj.tool_calls,
|
||||
)
|
||||
),
|
||||
)
|
||||
self.refresh_message_buffer()
|
||||
return updated_message
|
||||
|
||||
@@ -32,7 +32,7 @@ from letta.schemas.memory import (
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, UpdateMessage
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -586,8 +586,7 @@ class RESTClient(AbstractClient):
|
||||
tool_calls: Optional[List[ToolCall]] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
request = UpdateMessage(
|
||||
id=message_id,
|
||||
request = MessageUpdate(
|
||||
role=role,
|
||||
text=text,
|
||||
name=name,
|
||||
@@ -2148,8 +2147,8 @@ class LocalClient(AbstractClient):
|
||||
) -> Message:
|
||||
message = self.server.update_agent_message(
|
||||
agent_id=agent_id,
|
||||
request=UpdateMessage(
|
||||
id=message_id,
|
||||
message_id=message_id,
|
||||
request=MessageUpdate(
|
||||
role=role,
|
||||
text=text,
|
||||
name=name,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, DateTime, TypeDecorator
|
||||
from sqlalchemy import JSON, TypeDecorator
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
@@ -58,7 +57,6 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
||||
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
|
||||
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
# TODO: Add in after Agent ORM is created
|
||||
|
||||
@@ -13,7 +13,7 @@ from letta.constants import (
|
||||
)
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
FunctionCall,
|
||||
@@ -50,7 +50,7 @@ def add_inner_thoughts_to_tool_call(
|
||||
raise e
|
||||
|
||||
|
||||
class BaseMessage(LettaBase):
|
||||
class BaseMessage(OrmMetadataBase):
|
||||
__id_prefix__ = "message"
|
||||
|
||||
|
||||
@@ -66,10 +66,9 @@ class MessageCreate(BaseMessage):
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
|
||||
|
||||
class UpdateMessage(BaseMessage):
|
||||
class MessageUpdate(BaseMessage):
|
||||
"""Request to update a message"""
|
||||
|
||||
id: str = Field(..., description="The id of the message.")
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
text: Optional[str] = Field(None, description="The text of the message.")
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
@@ -109,9 +108,10 @@ class Message(BaseMessage):
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
|
||||
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
|
||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
|
||||
@@ -28,7 +28,7 @@ from letta.schemas.memory import (
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, UpdateMessage
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -422,14 +422,13 @@ def get_agent_messages(
|
||||
def update_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
request: UpdateMessage = Body(...),
|
||||
request: MessageUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}"
|
||||
return server.update_agent_message(agent_id=agent_id, request=request)
|
||||
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -67,7 +67,7 @@ from letta.schemas.memory import (
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source
|
||||
@@ -1662,12 +1662,12 @@ class SyncServer(Server):
|
||||
save_agent(letta_agent, self.ms)
|
||||
return message
|
||||
|
||||
def update_agent_message(self, agent_id: str, request: UpdateMessage) -> Message:
|
||||
def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
response = letta_agent.update_message(request=request)
|
||||
response = letta_agent.update_message(message_id=message_id, request=request)
|
||||
save_agent(letta_agent, self.ms)
|
||||
return response
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@@ -44,27 +45,38 @@ class MessageManager:
|
||||
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
|
||||
|
||||
@enforce_types
|
||||
def update_message_by_id(self, message_id: str, message: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
|
||||
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
|
||||
"""
|
||||
Updates an existing record in the database with values from the provided record object.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Fetch existing message from database
|
||||
msg = MessageModel.read(
|
||||
message = MessageModel.read(
|
||||
db_session=session,
|
||||
identifier=message_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
for column in MessageModel.__table__.columns:
|
||||
column_name = column.name
|
||||
if hasattr(message, column_name):
|
||||
new_value = getattr(message, column_name)
|
||||
setattr(msg, column_name, new_value)
|
||||
# Some safety checks specific to messages
|
||||
if message_update.tool_calls and message.role != MessageRole.assistant:
|
||||
raise ValueError(
|
||||
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
|
||||
)
|
||||
if message_update.tool_call_id and message.role != MessageRole.tool:
|
||||
raise ValueError(
|
||||
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
|
||||
)
|
||||
|
||||
# Commit changes
|
||||
return msg.update(db_session=session, actor=actor).to_pydantic()
|
||||
# get update dictionary
|
||||
update_data = message_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
# Remove redundant update fields
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(message, key, value)
|
||||
message.update(db_session=session, actor=actor)
|
||||
|
||||
return message.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
|
||||
|
||||
@@ -279,11 +279,7 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
assert chunk.total_tokens > 1000
|
||||
|
||||
# If stream tokens, we expect at least one inner thought
|
||||
if stream_tokens:
|
||||
assert inner_thoughts_count > 1, "Expected more than one inner thought"
|
||||
else:
|
||||
assert inner_thoughts_count == 1, "Expected one inner thought"
|
||||
|
||||
assert inner_thoughts_count >= 1, "Expected more than one inner thought"
|
||||
assert inner_thoughts_exist, "No inner thoughts found"
|
||||
assert send_message_ran, "send_message function call not found"
|
||||
assert done, "Message stream not done"
|
||||
|
||||
@@ -37,6 +37,7 @@ from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
@@ -598,16 +599,19 @@ def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, defa
|
||||
assert retrieved.text == hello_world_message_fixture.text
|
||||
|
||||
|
||||
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user):
|
||||
"""Test updating a message"""
|
||||
new_text = "Updated text"
|
||||
hello_world_message_fixture.text = new_text
|
||||
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, hello_world_message_fixture, actor=default_user)
|
||||
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(text=new_text), actor=other_user)
|
||||
assert updated is not None
|
||||
assert updated.text == new_text
|
||||
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||
assert retrieved.text == new_text
|
||||
|
||||
# Assert that orm metadata fields are populated
|
||||
assert retrieved.created_by_id == default_user.id
|
||||
assert retrieved.last_updated_by_id == other_user.id
|
||||
|
||||
|
||||
def test_message_delete(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test deleting a message"""
|
||||
|
||||
Reference in New Issue
Block a user