feat: modify message modification route to be via LettaMessage (#1184)

This commit is contained in:
Sarah Wooders
2025-03-06 15:33:16 -08:00
committed by GitHub
parent 5f69182063
commit 4aeaec3523
7 changed files with 141 additions and 69 deletions

View File

@@ -640,30 +640,6 @@ class RESTClient(AbstractClient):
# refresh and return agent
return self.get_agent(agent_state.id)
def update_message(
self,
agent_id: str,
message_id: str,
role: Optional[MessageRole] = None,
text: Optional[str] = None,
name: Optional[str] = None,
tool_calls: Optional[List[OpenAIToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
request = MessageUpdate(
role=role,
content=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
)
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to update message: {response.text}")
return Message(**response.json())
def update_agent(
self,
agent_id: str,
@@ -2436,30 +2412,6 @@ class LocalClient(AbstractClient):
# TODO: get full agent state
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
def update_message(
self,
agent_id: str,
message_id: str,
role: Optional[MessageRole] = None,
text: Optional[str] = None,
name: Optional[str] = None,
tool_calls: Optional[List[OpenAIToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
message = self.server.update_agent_message(
agent_id=agent_id,
message_id=message_id,
request=MessageUpdate(
role=role,
content=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
),
actor=self.user,
)
return message
def update_agent(
self,
agent_id: str,

View File

@@ -236,6 +236,32 @@ LettaMessageUnion = Annotated[
]
class UpdateSystemMessage(BaseModel):
content: Union[str, List[MessageContentUnion]]
message_type: Literal["system_message"] = "system_message"
class UpdateUserMessage(BaseModel):
content: Union[str, List[MessageContentUnion]]
message_type: Literal["user_message"] = "user_message"
class UpdateReasoningMessage(BaseModel):
reasoning: Union[str, List[MessageContentUnion]]
message_type: Literal["reasoning_message"] = "reasoning_message"
class UpdateAssistantMessage(BaseModel):
content: Union[str, List[MessageContentUnion]]
message_type: Literal["assistant_message"] = "assistant_message"
LettaMessageUpdateUnion = Annotated[
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
Field(discriminator="message_type"),
]
def create_letta_message_union_schema():
return {
"oneOf": [

View File

@@ -74,7 +74,7 @@ class MessageUpdate(BaseModel):
"""Request to update a message"""
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.")
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")

View File

@@ -15,7 +15,7 @@ from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
@@ -526,20 +526,20 @@ def list_messages(
)
@router.patch("/{agent_id}/messages/{message_id}", response_model=Message, operation_id="modify_message")
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message")
def modify_message(
agent_id: str,
message_id: str,
request: MessageUpdate = Body(...),
request: LettaMessageUpdateUnion = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the details of a message associated with an agent.
"""
# TODO: Get rid of agent_id here, it's not really relevant
# TODO: support modifying tool calls/returns
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
@router.post(

View File

@@ -1,5 +1,7 @@
import json
from typing import List, Optional
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from sqlalchemy import and_, or_
from letta.log import get_logger
@@ -7,6 +9,7 @@ from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
@@ -64,6 +67,44 @@ class MessageManager:
"""Create multiple messages."""
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
@enforce_types
def update_message_by_letta_message(
self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
) -> PydanticMessage:
"""
Updated the underlying messages table giving an update specified to the user-facing LettaMessage
"""
message = self.get_message_by_id(message_id=message_id, actor=actor)
if letta_message_update.message_type == "assistant_message":
# modify the tool call for send_message
# TODO: fix this if we add parallel tool calls
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
assert (
message.tool_calls[0].function.name == "send_message"
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
original_args = json.loads(message.tool_calls[0].function.arguments)
original_args["message"] = letta_message_update.content # override the assistant message
update_tool_call = message.tool_calls[0].__deepcopy__()
update_tool_call.function.arguments = json.dumps(original_args)
update_message = MessageUpdate(tool_calls=[update_tool_call])
elif letta_message_update.message_type == "reasoning_message":
update_message = MessageUpdate(content=letta_message_update.reasoning)
elif letta_message_update.message_type == "user_message" or letta_message_update.message_type == "system_message":
update_message = MessageUpdate(content=letta_message_update.content)
else:
raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}")
message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)
# convert back to LettaMessage
for letta_msg in message.to_letta_message(use_assistant_message=True):
if letta_msg.message_type == letta_message_update.message_type:
return letta_msg
# raise error if message type got modified
raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
@enforce_types
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""

View File

@@ -536,21 +536,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
client.delete_source(source.id)
def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can update the details of a message"""
# create a message
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
print("Messages=", message_response)
assert isinstance(message_response, LettaResponse)
assert isinstance(message_response.messages[-1], AssistantMessage)
message = message_response.messages[-1]
new_text = "this is a secret message"
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
assert new_message.text == new_text
def test_organization(client: RESTClient):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")

View File

@@ -24,6 +24,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
@@ -1153,6 +1154,73 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
"""
Test updating a message.
"""
messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user)
letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages)
system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0]
assistant_message = [msg for msg in letta_messages if msg.message_type == "assistant_message"][0]
user_message = [msg for msg in letta_messages if msg.message_type == "user_message"][0]
reasoning_message = [msg for msg in letta_messages if msg.message_type == "reasoning_message"][0]
# user message
update_user_message = UpdateUserMessage(content="Hello, Sarah!")
original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
assert original_user_message.content[0].text != update_user_message.content
server.message_manager.update_message_by_letta_message(
message_id=user_message.id, letta_message_update=update_user_message, actor=default_user
)
updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
assert updated_user_message.content[0].text == update_user_message.content
# system message
update_system_message = UpdateSystemMessage(content="You are a friendly assistant!")
original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
assert original_system_message.content[0].text != update_system_message.content
server.message_manager.update_message_by_letta_message(
message_id=system_message.id, letta_message_update=update_system_message, actor=default_user
)
updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
assert updated_system_message.content[0].text == update_system_message.content
# reasoning message
update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking")
original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning
server.message_manager.update_message_by_letta_message(
message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user
)
updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning
# assistant message
def parse_send_message(tool_call):
import json
function_call = tool_call.function
arguments = json.loads(function_call.arguments)
return arguments["message"]
update_assistant_message = UpdateAssistantMessage(content="I am an agent!")
original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
print("ORIGINAL", original_assistant_message.tool_calls)
print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0]))
assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content
server.message_manager.update_message_by_letta_message(
message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user
)
updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
print("UPDATED", updated_assistant_message.tool_calls)
print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0]))
assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content
# TODO: tool calls/responses
# ======================================================================================================================
# AgentManager Tests - Blocks Relationship
# ======================================================================================================================