From 36f105c1c76a2a1a62ab3658cf7acdc8dbc7f64c Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sat, 7 Sep 2024 20:03:16 -0700 Subject: [PATCH] feat: add support for returning type `MemGPTMessage` from cursor `GET` (#1723) --- memgpt/client/client.py | 10 +- memgpt/schemas/memgpt_message.py | 26 ++++ memgpt/schemas/message.py | 97 +++++++++++++- memgpt/server/rest_api/agents/message.py | 24 ++-- memgpt/server/server.py | 26 +++- tests/test_server.py | 154 ++++++++++++++++++++++- 6 files changed, 312 insertions(+), 25 deletions(-) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index ea9ec62d..584d23d7 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -642,7 +642,7 @@ class RESTClient(AbstractClient): messages (List[Message]): List of messages """ - params = {"before": before, "after": after, "limit": limit} + params = {"before": before, "after": after, "limit": limit, "msg_object": True} response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages", params=params, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to get messages: {response.text}") @@ -2151,7 +2151,13 @@ class LocalClient(AbstractClient): self.interface.clear() return self.server.get_agent_recall_cursor( - user_id=self.user_id, agent_id=agent_id, before=before, after=after, limit=limit, reverse=True + user_id=self.user_id, + agent_id=agent_id, + before=before, + after=after, + limit=limit, + reverse=True, + return_message_object=True, ) def list_models(self) -> List[LLMConfig]: diff --git a/memgpt/schemas/memgpt_message.py b/memgpt/schemas/memgpt_message.py index b6bb0f19..1182ea52 100644 --- a/memgpt/schemas/memgpt_message.py +++ b/memgpt/schemas/memgpt_message.py @@ -29,6 +29,32 @@ class MemGPTMessage(BaseModel): return dt.isoformat(timespec="seconds") +class SystemMessage(MemGPTMessage): + """ + A message generated by the system. Never streamed back on a response, only used for cursor pagination. + + Attributes: + message (str): The message sent by the system + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + """ + + message: str + + +class UserMessage(MemGPTMessage): + """ + A message sent by the user. Never streamed back on a response, only used for cursor pagination. + + Attributes: + message (str): The message sent by the user + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + """ + + message: str + + class InternalMonologue(MemGPTMessage): """ Representation of an agent's internal monologue. diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index 5c8adfe4..a3331e98 100644 --- a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -2,7 +2,7 @@ import copy import json import warnings from datetime import datetime, timezone -from typing import List, Optional, Union +from typing import List, Optional from pydantic import Field, field_validator @@ -10,7 +10,15 @@ from memgpt.constants import TOOL_CALL_ID_MAX_LEN from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG from memgpt.schemas.enums import MessageRole from memgpt.schemas.memgpt_base import MemGPTBase -from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage +from memgpt.schemas.memgpt_message import ( + FunctionCall, + FunctionCallMessage, + FunctionReturn, + InternalMonologue, + MemGPTMessage, + SystemMessage, + UserMessage, +) from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction from memgpt.utils import get_utc_time, is_utc_datetime, json_dumps @@ -96,11 +104,90 @@ class Message(BaseMessage): json_message["created_at"] = self.created_at.isoformat() return json_message - def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]: + def to_memgpt_message(self) -> List[MemGPTMessage]: """Convert message object (in DB format) to the style used by the original MemGPT API""" - # NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call) - raise NotImplementedError + messages = [] + + if self.role == MessageRole.assistant: + if self.text is not None: + # This is type InnerThoughts + messages.append( + InternalMonologue( + id=self.id, + date=self.created_at, + internal_monologue=self.text, + ) + ) + if self.tool_calls is not None: + # This is type FunctionCall + for tool_call in self.tool_calls: + messages.append( + FunctionCallMessage( + id=self.id, + date=self.created_at, + function_call=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + ) + elif self.role == MessageRole.tool: + # This is type FunctionReturn + # Try to interpret the function return, recall that this is how we packaged: + # def package_function_response(was_success, response_string, timestamp=None): + # formatted_time = get_local_time() if timestamp is None else timestamp + # packaged_message = { + # "status": "OK" if was_success else "Failed", + # "message": response_string, + # "time": formatted_time, + # } + assert self.text is not None, self + try: + function_return = json.loads(self.text) + status = function_return["status"] + if status == "OK": + status_enum = "success" + elif status == "Failed": + status_enum = "error" + else: + raise ValueError(f"Invalid status: {status}") + except json.JSONDecodeError: + raise ValueError(f"Failed to decode function return: {self.text}") + messages.append( + # TODO make sure this is what the API returns + # function_return may not match exactly... + FunctionReturn( + id=self.id, + date=self.created_at, + function_return=self.text, + status=status_enum, + ) + ) + elif self.role == MessageRole.user: + # This is type UserMessage + assert self.text is not None, self + messages.append( + UserMessage( + id=self.id, + date=self.created_at, + message=self.text, + ) + ) + elif self.role == MessageRole.system: + # This is type SystemMessage + assert self.text is not None, self + messages.append( + SystemMessage( + id=self.id, + date=self.created_at, + message=self.text, + ) + ) + else: + raise ValueError(self.role) + + return messages @staticmethod def dict_to_message( diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index b0da83fc..ee1e53ca 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -129,32 +129,26 @@ async def send_message_to_agent( def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/{agent_id}/messages/context/", tags=["agents"], response_model=List[Message]) - def get_agent_messages_in_context( - agent_id: str, - start: int = Query(..., description="Message index to start on (reverse chronological)."), - count: int = Query(..., description="How many messages to retrieve."), - user_id: str = Depends(get_current_user_with_server), - ): - """ - Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate. - """ - interface.clear() - messages = server.get_agent_messages(agent_id=agent_id, start=start, count=count) - return messages - @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message]) def get_agent_messages( agent_id: str, before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), limit: int = Query(10, description="Maximum number of messages to retrieve."), + msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."), user_id: str = Depends(get_current_user_with_server), ): """ Retrieve message history for an agent. """ interface.clear() - return server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, before=before, limit=limit, reverse=True) + return server.get_agent_recall_cursor( + user_id=user_id, + agent_id=agent_id, + before=before, + limit=limit, + reverse=True, + return_message_object=msg_object, + ) @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse) async def send_message( diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 4fa7869a..f5a2f23a 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -54,6 +54,7 @@ from memgpt.schemas.embedding_config import EmbeddingConfig from memgpt.schemas.enums import JobStatus from memgpt.schemas.job import Job from memgpt.schemas.llm_config import LLMConfig +from memgpt.schemas.memgpt_message import MemGPTMessage from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from memgpt.schemas.message import Message from memgpt.schemas.openai.chat_completion_response import UsageStatistics @@ -990,7 +991,13 @@ class SyncServer(Server): message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=message_id) return message - def get_agent_messages(self, agent_id: str, start: int, count: int) -> List[Message]: + def get_agent_messages( + self, + agent_id: str, + start: int, + count: int, + return_message_object: bool = True, + ) -> Union[List[Message], List[MemGPTMessage]]: """Paginated query of all messages in agent message queue""" # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(agent_id=agent_id) @@ -1025,6 +1032,7 @@ class SyncServer(Server): # return messages in reverse chronological order messages = sorted(page, key=lambda x: x.created_at, reverse=True) + assert all(isinstance(m, Message) for m in messages) ## Convert to json ## Add a tag indicating in-context or not @@ -1033,6 +1041,9 @@ class SyncServer(Server): # for d in json_messages: # d["in_context"] = True if str(d["id"]) in in_context_message_ids else False + if not return_message_object: + messages = [msg for m in messages for msg in m.to_memgpt_message()] + return messages def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]: @@ -1118,7 +1129,8 @@ class SyncServer(Server): order_by: Optional[str] = "created_at", order: Optional[str] = "asc", reverse: Optional[bool] = False, - ) -> List[Message]: + return_message_object: bool = True, + ) -> Union[List[Message], List[MemGPTMessage]]: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: @@ -1131,6 +1143,16 @@ class SyncServer(Server): cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor( after=after, before=before, limit=limit, order_by=order_by, reverse=reverse ) + + assert all(isinstance(m, Message) for m in records) + + if not return_message_object: + # If we're GETing messages in reverse, we need to reverse the inner list (generated by to_memgpt_message) + if reverse: + records = [msg for m in records for msg in m.to_memgpt_message()[::-1]] + else: + records = [msg for m in records for msg in m.to_memgpt_message()] + return records def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]: diff --git a/tests/test_server.py b/tests/test_server.py index 11d8f059..1046b28b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,11 +4,21 @@ import pytest import memgpt.utils as utils from memgpt.constants import BASE_TOOLS +from memgpt.schemas.enums import MessageRole utils.DEBUG = True from memgpt.config import MemGPTConfig from memgpt.schemas.agent import CreateAgent +from memgpt.schemas.memgpt_message import ( + FunctionCallMessage, + FunctionReturn, + InternalMonologue, + MemGPTMessage, + SystemMessage, + UserMessage, +) from memgpt.schemas.memory import ChatMemory +from memgpt.schemas.message import Message from memgpt.schemas.source import SourceCreate from memgpt.schemas.user import UserCreate from memgpt.server.server import SyncServer @@ -83,7 +93,7 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id): @pytest.mark.order(1) -def test_user_message(server, user_id, agent_id): +def test_user_message_memory(server, user_id, agent_id): try: server.user_message(user_id=user_id, agent_id=agent_id, message="/memory") raise Exception("user_message call should have failed") @@ -223,3 +233,145 @@ def test_get_archival_memory(server, user_id, agent_id): # test safe empty return passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000) assert len(passage_none) == 0 + + +def _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False): + """Reverse is off by default, the GET goes in chronological order""" + + messages = server.get_agent_recall_cursor( + user_id=user_id, + agent_id=agent_id, + limit=1000, + reverse=reverse, + ) + # messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000) + assert all(isinstance(m, Message) for m in messages) + + memgpt_messages = server.get_agent_recall_cursor( + user_id=user_id, + agent_id=agent_id, + limit=1000, + reverse=reverse, + return_message_object=False, + ) + # memgpt_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False) + assert all(isinstance(m, MemGPTMessage) for m in memgpt_messages) + + # Loop through `messages` while also looping through `memgpt_messages` + # Each message in `messages` should have 1+ corresponding messages in `memgpt_messages` + # If role of message (in `messages`) is `assistant`, + # then there should be two messages in `memgpt_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage. + # If role of message (in `messages`) is `user`, then there should be one message in `memgpt_messages` which is type UserMessage. + # If role of message (in `messages`) is `system`, then there should be one message in `memgpt_messages` which is type SystemMessage. + # If role of message (in `messages`) is `tool`, then there should be one message in `memgpt_messages` which is type FunctionReturn. + + print("MESSAGES (obj):") + for i, m in enumerate(messages): + # print(m) + print(f"{i}: {m.role}, {m.text[:50]}...") + # print(m.role) + + print("MEMGPT_MESSAGES:") + for i, m in enumerate(memgpt_messages): + print(f"{i}: {type(m)} ...{str(m)[-50:]}") + + # Collect system messages and their texts + system_messages = [m for m in messages if m.role == MessageRole.system] + system_texts = [m.text for m in system_messages] + + # If there are multiple system messages, print the diff + if len(system_messages) > 1: + print("Differences between system messages:") + for i in range(len(system_texts) - 1): + for j in range(i + 1, len(system_texts)): + import difflib + + diff = difflib.unified_diff( + system_texts[i].splitlines(), + system_texts[j].splitlines(), + fromfile=f"System Message {i+1}", + tofile=f"System Message {j+1}", + lineterm="", + ) + print("\n".join(diff)) + else: + print("There is only one or no system message.") + + memgpt_message_index = 0 + for i, message in enumerate(messages): + assert isinstance(message, Message) + + print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}") + while memgpt_message_index < len(memgpt_messages): + memgpt_message = memgpt_messages[memgpt_message_index] + print(f"memgpt_message {memgpt_message_index}: {str(memgpt_message)[:50]}") + + if message.role == MessageRole.assistant: + print(f"i={i}, M=assistant, MM={type(memgpt_message)}") + + # If reverse, function call will come first + if reverse: + + # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages + if message.tool_calls is not None: + for tool_call in message.tool_calls: + assert isinstance(memgpt_message, FunctionCallMessage) + memgpt_message_index += 1 + memgpt_message = memgpt_messages[memgpt_message_index] + + if message.text is not None: + assert isinstance(memgpt_message, InternalMonologue) + memgpt_message_index += 1 + memgpt_message = memgpt_messages[memgpt_message_index] + else: + # If there's no inner thoughts then there needs to be a tool call + assert message.tool_calls is not None + + else: + + if message.text is not None: + assert isinstance(memgpt_message, InternalMonologue) + memgpt_message_index += 1 + memgpt_message = memgpt_messages[memgpt_message_index] + else: + # If there's no inner thoughts then there needs to be a tool call + assert message.tool_calls is not None + + # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages + if message.tool_calls is not None: + for tool_call in message.tool_calls: + assert isinstance(memgpt_message, FunctionCallMessage) + assert tool_call.function.name == memgpt_message.function_call.name + assert tool_call.function.arguments == memgpt_message.function_call.arguments + memgpt_message_index += 1 + memgpt_message = memgpt_messages[memgpt_message_index] + + elif message.role == MessageRole.user: + print(f"i={i}, M=user, MM={type(memgpt_message)}") + assert isinstance(memgpt_message, UserMessage) + assert message.text == memgpt_message.message + memgpt_message_index += 1 + + elif message.role == MessageRole.system: + print(f"i={i}, M=system, MM={type(memgpt_message)}") + assert isinstance(memgpt_message, SystemMessage) + assert message.text == memgpt_message.message + memgpt_message_index += 1 + + elif message.role == MessageRole.tool: + print(f"i={i}, M=tool, MM={type(memgpt_message)}") + assert isinstance(memgpt_message, FunctionReturn) + # Check the the value in `text` is the same + assert message.text == memgpt_message.function_return + memgpt_message_index += 1 + + else: + raise ValueError(f"Unexpected message role: {message.role}") + + # Move to the next message in the original messages list + break + + +def test_get_messages_memgpt_format(server, user_id, agent_id): + _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False) + _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=True)