feat: modify message modification route to be via LettaMessage (#1184)
This commit is contained in:
@@ -640,30 +640,6 @@ class RESTClient(AbstractClient):
|
||||
# refresh and return agent
|
||||
return self.get_agent(agent_state.id)
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
role: Optional[MessageRole] = None,
|
||||
text: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_calls: Optional[List[OpenAIToolCall]] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
request = MessageUpdate(
|
||||
role=role,
|
||||
content=text,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update message: {response.text}")
|
||||
return Message(**response.json())
|
||||
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2436,30 +2412,6 @@ class LocalClient(AbstractClient):
|
||||
# TODO: get full agent state
|
||||
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
role: Optional[MessageRole] = None,
|
||||
text: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_calls: Optional[List[OpenAIToolCall]] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
message = self.server.update_agent_message(
|
||||
agent_id=agent_id,
|
||||
message_id=message_id,
|
||||
request=MessageUpdate(
|
||||
role=role,
|
||||
content=text,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
return message
|
||||
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
||||
@@ -236,6 +236,32 @@ LettaMessageUnion = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
|
||||
|
||||
class UpdateReasoningMessage(BaseModel):
|
||||
reasoning: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
|
||||
|
||||
LettaMessageUpdateUnion = Annotated[
|
||||
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
|
||||
@@ -74,7 +74,7 @@ class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.")
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
|
||||
@@ -15,7 +15,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
@@ -526,20 +526,20 @@ def list_messages(
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=Message, operation_id="modify_message")
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message")
|
||||
def modify_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
request: MessageUpdate = Body(...),
|
||||
request: LettaMessageUpdateUnion = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
# TODO: Get rid of agent_id here, it's not really relevant
|
||||
# TODO: support modifying tool calls/returns
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from letta.log import get_logger
|
||||
@@ -7,6 +9,7 @@ from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -64,6 +67,44 @@ class MessageManager:
|
||||
"""Create multiple messages."""
|
||||
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
|
||||
|
||||
@enforce_types
|
||||
def update_message_by_letta_message(
|
||||
self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
|
||||
) -> PydanticMessage:
|
||||
"""
|
||||
Updated the underlying messages table giving an update specified to the user-facing LettaMessage
|
||||
"""
|
||||
message = self.get_message_by_id(message_id=message_id, actor=actor)
|
||||
if letta_message_update.message_type == "assistant_message":
|
||||
# modify the tool call for send_message
|
||||
# TODO: fix this if we add parallel tool calls
|
||||
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
||||
assert (
|
||||
message.tool_calls[0].function.name == "send_message"
|
||||
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||
original_args = json.loads(message.tool_calls[0].function.arguments)
|
||||
original_args["message"] = letta_message_update.content # override the assistant message
|
||||
update_tool_call = message.tool_calls[0].__deepcopy__()
|
||||
update_tool_call.function.arguments = json.dumps(original_args)
|
||||
|
||||
update_message = MessageUpdate(tool_calls=[update_tool_call])
|
||||
elif letta_message_update.message_type == "reasoning_message":
|
||||
update_message = MessageUpdate(content=letta_message_update.reasoning)
|
||||
elif letta_message_update.message_type == "user_message" or letta_message_update.message_type == "system_message":
|
||||
update_message = MessageUpdate(content=letta_message_update.content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}")
|
||||
|
||||
message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)
|
||||
|
||||
# convert back to LettaMessage
|
||||
for letta_msg in message.to_letta_message(use_assistant_message=True):
|
||||
if letta_msg.message_type == letta_message_update.message_type:
|
||||
return letta_msg
|
||||
|
||||
# raise error if message type got modified
|
||||
raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
|
||||
|
||||
@enforce_types
|
||||
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
|
||||
"""
|
||||
|
||||
@@ -536,21 +536,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
client.delete_source(source.id)
|
||||
|
||||
|
||||
def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the details of a message"""
|
||||
|
||||
# create a message
|
||||
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
print("Messages=", message_response)
|
||||
assert isinstance(message_response, LettaResponse)
|
||||
assert isinstance(message_response.messages[-1], AssistantMessage)
|
||||
message = message_response.messages[-1]
|
||||
|
||||
new_text = "this is a secret message"
|
||||
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
|
||||
assert new_message.text == new_text
|
||||
|
||||
|
||||
def test_organization(client: RESTClient):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||
|
||||
@@ -24,6 +24,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
@@ -1153,6 +1154,73 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
|
||||
def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
|
||||
"""
|
||||
Test updating a message.
|
||||
"""
|
||||
|
||||
messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user)
|
||||
letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages)
|
||||
|
||||
system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0]
|
||||
assistant_message = [msg for msg in letta_messages if msg.message_type == "assistant_message"][0]
|
||||
user_message = [msg for msg in letta_messages if msg.message_type == "user_message"][0]
|
||||
reasoning_message = [msg for msg in letta_messages if msg.message_type == "reasoning_message"][0]
|
||||
|
||||
# user message
|
||||
update_user_message = UpdateUserMessage(content="Hello, Sarah!")
|
||||
original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
|
||||
assert original_user_message.content[0].text != update_user_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=user_message.id, letta_message_update=update_user_message, actor=default_user
|
||||
)
|
||||
updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
|
||||
assert updated_user_message.content[0].text == update_user_message.content
|
||||
|
||||
# system message
|
||||
update_system_message = UpdateSystemMessage(content="You are a friendly assistant!")
|
||||
original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
|
||||
assert original_system_message.content[0].text != update_system_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=system_message.id, letta_message_update=update_system_message, actor=default_user
|
||||
)
|
||||
updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
|
||||
assert updated_system_message.content[0].text == update_system_message.content
|
||||
|
||||
# reasoning message
|
||||
update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking")
|
||||
original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
|
||||
assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user
|
||||
)
|
||||
updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
|
||||
assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning
|
||||
|
||||
# assistant message
|
||||
def parse_send_message(tool_call):
|
||||
import json
|
||||
|
||||
function_call = tool_call.function
|
||||
arguments = json.loads(function_call.arguments)
|
||||
return arguments["message"]
|
||||
|
||||
update_assistant_message = UpdateAssistantMessage(content="I am an agent!")
|
||||
original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
|
||||
print("ORIGINAL", original_assistant_message.tool_calls)
|
||||
print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0]))
|
||||
assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user
|
||||
)
|
||||
updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
|
||||
print("UPDATED", updated_assistant_message.tool_calls)
|
||||
print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0]))
|
||||
assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content
|
||||
|
||||
# TODO: tool calls/responses
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Blocks Relationship
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user