diff --git a/letta/client/client.py b/letta/client/client.py index e26ffa0a..3e9de117 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -640,30 +640,6 @@ class RESTClient(AbstractClient): # refresh and return agent return self.get_agent(agent_state.id) - def update_message( - self, - agent_id: str, - message_id: str, - role: Optional[MessageRole] = None, - text: Optional[str] = None, - name: Optional[str] = None, - tool_calls: Optional[List[OpenAIToolCall]] = None, - tool_call_id: Optional[str] = None, - ) -> Message: - request = MessageUpdate( - role=role, - content=text, - name=name, - tool_calls=tool_calls, - tool_call_id=tool_call_id, - ) - response = requests.patch( - f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers - ) - if response.status_code != 200: - raise ValueError(f"Failed to update message: {response.text}") - return Message(**response.json()) - def update_agent( self, agent_id: str, @@ -2436,30 +2412,6 @@ class LocalClient(AbstractClient): # TODO: get full agent state return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user) - def update_message( - self, - agent_id: str, - message_id: str, - role: Optional[MessageRole] = None, - text: Optional[str] = None, - name: Optional[str] = None, - tool_calls: Optional[List[OpenAIToolCall]] = None, - tool_call_id: Optional[str] = None, - ) -> Message: - message = self.server.update_agent_message( - agent_id=agent_id, - message_id=message_id, - request=MessageUpdate( - role=role, - content=text, - name=name, - tool_calls=tool_calls, - tool_call_id=tool_call_id, - ), - actor=self.user, - ) - return message - def update_agent( self, agent_id: str, diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index b66c7c12..305420e2 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -236,6 +236,32 @@ LettaMessageUnion = Annotated[ ] +class UpdateSystemMessage(BaseModel): + content: Union[str, List[MessageContentUnion]] + message_type: Literal["system_message"] = "system_message" + + +class UpdateUserMessage(BaseModel): + content: Union[str, List[MessageContentUnion]] + message_type: Literal["user_message"] = "user_message" + + +class UpdateReasoningMessage(BaseModel): + reasoning: Union[str, List[MessageContentUnion]] + message_type: Literal["reasoning_message"] = "reasoning_message" + + +class UpdateAssistantMessage(BaseModel): + content: Union[str, List[MessageContentUnion]] + message_type: Literal["assistant_message"] = "assistant_message" + + +LettaMessageUpdateUnion = Annotated[ + Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage], + Field(discriminator="message_type"), +] + + def create_letta_message_union_schema(): return { "oneOf": [ diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 4490f7d7..7cf66366 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -74,7 +74,7 @@ class MessageUpdate(BaseModel): """Request to update a message""" role: Optional[MessageRole] = Field(None, description="The role of the participant.") - content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.") + content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.") # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message) # user_id: Optional[str] = Field(None, description="The unique identifier of the user.") # agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index e859c605..3af6033d 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -15,7 +15,7 @@ from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig -from letta.schemas.letta_message import LettaMessageUnion +from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory @@ -526,20 +526,20 @@ def list_messages( ) -@router.patch("/{agent_id}/messages/{message_id}", response_model=Message, operation_id="modify_message") +@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message") def modify_message( agent_id: str, message_id: str, - request: MessageUpdate = Body(...), + request: LettaMessageUpdateUnion = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the details of a message associated with an agent. """ - # TODO: Get rid of agent_id here, it's not really relevant + # TODO: support modifying tool calls/returns actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor) + return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor) @router.post( diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 26f7c27b..6cd3efc7 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,5 +1,7 @@ +import json from typing import List, Optional +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from sqlalchemy import and_, or_ from letta.log import get_logger @@ -7,6 +9,7 @@ from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import LettaMessageUpdateUnion from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser @@ -64,6 +67,44 @@ class MessageManager: """Create multiple messages.""" return [self.create_message(m, actor=actor) for m in pydantic_msgs] + @enforce_types + def update_message_by_letta_message( + self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser + ) -> PydanticMessage: + """ + Updated the underlying messages table giving an update specified to the user-facing LettaMessage + """ + message = self.get_message_by_id(message_id=message_id, actor=actor) + if letta_message_update.message_type == "assistant_message": + # modify the tool call for send_message + # TODO: fix this if we add parallel tool calls + # TODO: note this only works if the AssistantMessage is generated by the standard send_message + assert ( + message.tool_calls[0].function.name == "send_message" + ), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}" + original_args = json.loads(message.tool_calls[0].function.arguments) + original_args["message"] = letta_message_update.content # override the assistant message + update_tool_call = message.tool_calls[0].__deepcopy__() + update_tool_call.function.arguments = json.dumps(original_args) + + update_message = MessageUpdate(tool_calls=[update_tool_call]) + elif letta_message_update.message_type == "reasoning_message": + update_message = MessageUpdate(content=letta_message_update.reasoning) + elif letta_message_update.message_type == "user_message" or letta_message_update.message_type == "system_message": + update_message = MessageUpdate(content=letta_message_update.content) + else: + raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}") + + message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor) + + # convert back to LettaMessage + for letta_msg in message.to_letta_message(use_assistant_message=True): + if letta_msg.message_type == letta_message_update.message_type: + return letta_msg + + # raise error if message type got modified + raise ValueError(f"Message type got modified: {letta_message_update.message_type}") + @enforce_types def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage: """ diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 00ee65ba..2d7ed16e 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -536,21 +536,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): client.delete_source(source.id) -def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState): - """Test that we can update the details of a message""" - - # create a message - message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") - print("Messages=", message_response) - assert isinstance(message_response, LettaResponse) - assert isinstance(message_response.messages[-1], AssistantMessage) - message = message_response.messages[-1] - - new_text = "this is a secret message" - new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) - assert new_message.text == new_text - - def test_organization(client: RESTClient): if isinstance(client, LocalClient): pytest.skip("Skipping test_organization because LocalClient does not support organizations") diff --git a/tests/test_managers.py b/tests/test_managers.py index ec4d2033..b86cf718 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -24,6 +24,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig +from letta.schemas.letta_message import LettaMessage, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -1153,6 +1154,73 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 +def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): + """ + Test updating a message. + """ + + messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user) + letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages) + + system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0] + assistant_message = [msg for msg in letta_messages if msg.message_type == "assistant_message"][0] + user_message = [msg for msg in letta_messages if msg.message_type == "user_message"][0] + reasoning_message = [msg for msg in letta_messages if msg.message_type == "reasoning_message"][0] + + # user message + update_user_message = UpdateUserMessage(content="Hello, Sarah!") + original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user) + assert original_user_message.content[0].text != update_user_message.content + server.message_manager.update_message_by_letta_message( + message_id=user_message.id, letta_message_update=update_user_message, actor=default_user + ) + updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user) + assert updated_user_message.content[0].text == update_user_message.content + + # system message + update_system_message = UpdateSystemMessage(content="You are a friendly assistant!") + original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user) + assert original_system_message.content[0].text != update_system_message.content + server.message_manager.update_message_by_letta_message( + message_id=system_message.id, letta_message_update=update_system_message, actor=default_user + ) + updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user) + assert updated_system_message.content[0].text == update_system_message.content + + # reasoning message + update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking") + original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user) + assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning + server.message_manager.update_message_by_letta_message( + message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user + ) + updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user) + assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning + + # assistant message + def parse_send_message(tool_call): + import json + + function_call = tool_call.function + arguments = json.loads(function_call.arguments) + return arguments["message"] + + update_assistant_message = UpdateAssistantMessage(content="I am an agent!") + original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user) + print("ORIGINAL", original_assistant_message.tool_calls) + print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0])) + assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content + server.message_manager.update_message_by_letta_message( + message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user + ) + updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user) + print("UPDATED", updated_assistant_message.tool_calls) + print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0])) + assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content + + # TODO: tool calls/responses + + # ====================================================================================================================== # AgentManager Tests - Blocks Relationship # ======================================================================================================================