fix: Updating messages (#2186)

This commit is contained in:
Matthew Zhou
2024-12-07 14:09:20 -08:00
committed by GitHub
parent 79cc78f5cb
commit 1f57569116
10 changed files with 65 additions and 84 deletions

View File

@@ -36,7 +36,7 @@ from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, UpdateMessage
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.openai.chat_completion_request import (
Tool as ChatCompletionRequestTool,
)
@@ -512,9 +512,10 @@ class Agent(BaseAgent):
for m in self._messages:
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
if m.created_at:
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
"""Set the messages in the buffer to the message IDs list"""
@@ -1405,36 +1406,10 @@ class Agent(BaseAgent):
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
)
def update_message(self, request: UpdateMessage) -> Message:
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
"""Update the details of a message associated with an agent"""
message = self.message_manager.get_message_by_id(message_id=request.id, actor=self.user)
if message is None:
raise ValueError(f"Message with id {request.id} not found")
assert isinstance(message, Message), f"Message is not a Message object: {type(message)}"
# Override fields
# NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof
if request.role:
message.role = request.role
if request.text:
message.text = request.text
if request.name:
message.name = request.name
if request.tool_calls:
assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages"
message.tool_calls = request.tool_calls
if request.tool_call_id:
assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages"
message.tool_call_id = request.tool_call_id
# Save the updated message
self.message_manager.update_message_by_id(message_id=message.id, message=message, actor=self.user)
# Return the updated message
updated_message = self.message_manager.get_message_by_id(message_id=message.id, actor=self.user)
if updated_message is None:
raise ValueError(f"Error persisting message - message with id {request.id} not found")
updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user)
return updated_message
# TODO(sarah): should we be creating a new message here, or just editing a message?
@@ -1444,10 +1419,10 @@ class Agent(BaseAgent):
msg_obj = self._messages[x]
if msg_obj.role == MessageRole.assistant:
updated_message = self.update_message(
request=UpdateMessage(
id=msg_obj.id,
message_id=msg_obj.id,
request=MessageUpdate(
text=new_thought,
)
),
)
self.refresh_message_buffer()
return updated_message
@@ -1486,10 +1461,10 @@ class Agent(BaseAgent):
# Write the update to the DB
updated_message = self.update_message(
request=UpdateMessage(
id=message_obj.id,
message_id=message_obj.id,
request=MessageUpdate(
tool_calls=message_obj.tool_calls,
)
),
)
self.refresh_message_buffer()
return updated_message