feat: add support for returning type MemGPTMessage from cursor GET (#1723)

This commit is contained in:
Charles Packer
2024-09-07 20:03:16 -07:00
committed by GitHub
parent e6247692e5
commit 36f105c1c7
6 changed files with 312 additions and 25 deletions

View File

@@ -642,7 +642,7 @@ class RESTClient(AbstractClient):
messages (List[Message]): List of messages
"""
params = {"before": before, "after": after, "limit": limit}
params = {"before": before, "after": after, "limit": limit, "msg_object": True}
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get messages: {response.text}")
@@ -2151,7 +2151,13 @@ class LocalClient(AbstractClient):
self.interface.clear()
return self.server.get_agent_recall_cursor(
user_id=self.user_id, agent_id=agent_id, before=before, after=after, limit=limit, reverse=True
user_id=self.user_id,
agent_id=agent_id,
before=before,
after=after,
limit=limit,
reverse=True,
return_message_object=True,
)
def list_models(self) -> List[LLMConfig]:

View File

@@ -29,6 +29,32 @@ class MemGPTMessage(BaseModel):
return dt.isoformat(timespec="seconds")
class SystemMessage(MemGPTMessage):
"""
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
Attributes:
message (str): The message sent by the system
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
"""
message: str
class UserMessage(MemGPTMessage):
"""
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
Attributes:
message (str): The message sent by the user
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
"""
message: str
class InternalMonologue(MemGPTMessage):
"""
Representation of an agent's internal monologue.

View File

@@ -2,7 +2,7 @@ import copy
import json
import warnings
from datetime import datetime, timezone
from typing import List, Optional, Union
from typing import List, Optional
from pydantic import Field, field_validator
@@ -10,7 +10,15 @@ from memgpt.constants import TOOL_CALL_ID_MAX_LEN
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_message import (
FunctionCall,
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
MemGPTMessage,
SystemMessage,
UserMessage,
)
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from memgpt.utils import get_utc_time, is_utc_datetime, json_dumps
@@ -96,11 +104,90 @@ class Message(BaseMessage):
json_message["created_at"] = self.created_at.isoformat()
return json_message
def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
def to_memgpt_message(self) -> List[MemGPTMessage]:
"""Convert message object (in DB format) to the style used by the original MemGPT API"""
# NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
raise NotImplementedError
messages = []
if self.role == MessageRole.assistant:
if self.text is not None:
# This is type InnerThoughts
messages.append(
InternalMonologue(
id=self.id,
date=self.created_at,
internal_monologue=self.text,
)
)
if self.tool_calls is not None:
# This is type FunctionCall
for tool_call in self.tool_calls:
messages.append(
FunctionCallMessage(
id=self.id,
date=self.created_at,
function_call=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
)
elif self.role == MessageRole.tool:
# This is type FunctionReturn
# Try to interpret the function return, recall that this is how we packaged:
# def package_function_response(was_success, response_string, timestamp=None):
# formatted_time = get_local_time() if timestamp is None else timestamp
# packaged_message = {
# "status": "OK" if was_success else "Failed",
# "message": response_string,
# "time": formatted_time,
# }
assert self.text is not None, self
try:
function_return = json.loads(self.text)
status = function_return["status"]
if status == "OK":
status_enum = "success"
elif status == "Failed":
status_enum = "error"
else:
raise ValueError(f"Invalid status: {status}")
except json.JSONDecodeError:
raise ValueError(f"Failed to decode function return: {self.text}")
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
FunctionReturn(
id=self.id,
date=self.created_at,
function_return=self.text,
status=status_enum,
)
)
elif self.role == MessageRole.user:
# This is type UserMessage
assert self.text is not None, self
messages.append(
UserMessage(
id=self.id,
date=self.created_at,
message=self.text,
)
)
elif self.role == MessageRole.system:
# This is type SystemMessage
assert self.text is not None, self
messages.append(
SystemMessage(
id=self.id,
date=self.created_at,
message=self.text,
)
)
else:
raise ValueError(self.role)
return messages
@staticmethod
def dict_to_message(

View File

@@ -129,32 +129,26 @@ async def send_message_to_agent(
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/messages/context/", tags=["agents"], response_model=List[Message])
def get_agent_messages_in_context(
agent_id: str,
start: int = Query(..., description="Message index to start on (reverse chronological)."),
count: int = Query(..., description="How many messages to retrieve."),
user_id: str = Depends(get_current_user_with_server),
):
"""
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
"""
interface.clear()
messages = server.get_agent_messages(agent_id=agent_id, start=start, count=count)
return messages
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
def get_agent_messages(
agent_id: str,
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
user_id: str = Depends(get_current_user_with_server),
):
"""
Retrieve message history for an agent.
"""
interface.clear()
return server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, before=before, limit=limit, reverse=True)
return server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
before=before,
limit=limit,
reverse=True,
return_message_object=msg_object,
)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
async def send_message(

View File

@@ -54,6 +54,7 @@ from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.job import Job
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memgpt_message import MemGPTMessage
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
@@ -990,7 +991,13 @@ class SyncServer(Server):
message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=message_id)
return message
def get_agent_messages(self, agent_id: str, start: int, count: int) -> List[Message]:
def get_agent_messages(
self,
agent_id: str,
start: int,
count: int,
return_message_object: bool = True,
) -> Union[List[Message], List[MemGPTMessage]]:
"""Paginated query of all messages in agent message queue"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
@@ -1025,6 +1032,7 @@ class SyncServer(Server):
# return messages in reverse chronological order
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
assert all(isinstance(m, Message) for m in messages)
## Convert to json
## Add a tag indicating in-context or not
@@ -1033,6 +1041,9 @@ class SyncServer(Server):
# for d in json_messages:
# d["in_context"] = True if str(d["id"]) in in_context_message_ids else False
if not return_message_object:
messages = [msg for m in messages for msg in m.to_memgpt_message()]
return messages
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
@@ -1118,7 +1129,8 @@ class SyncServer(Server):
order_by: Optional[str] = "created_at",
order: Optional[str] = "asc",
reverse: Optional[bool] = False,
) -> List[Message]:
return_message_object: bool = True,
) -> Union[List[Message], List[MemGPTMessage]]:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
@@ -1131,6 +1143,16 @@ class SyncServer(Server):
cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor(
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
)
assert all(isinstance(m, Message) for m in records)
if not return_message_object:
# If we're GETing messages in reverse, we need to reverse the inner list (generated by to_memgpt_message)
if reverse:
records = [msg for m in records for msg in m.to_memgpt_message()[::-1]]
else:
records = [msg for m in records for msg in m.to_memgpt_message()]
return records
def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]:

View File

@@ -4,11 +4,21 @@ import pytest
import memgpt.utils as utils
from memgpt.constants import BASE_TOOLS
from memgpt.schemas.enums import MessageRole
utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.schemas.agent import CreateAgent
from memgpt.schemas.memgpt_message import (
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
MemGPTMessage,
SystemMessage,
UserMessage,
)
from memgpt.schemas.memory import ChatMemory
from memgpt.schemas.message import Message
from memgpt.schemas.source import SourceCreate
from memgpt.schemas.user import UserCreate
from memgpt.server.server import SyncServer
@@ -83,7 +93,7 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):
@pytest.mark.order(1)
def test_user_message(server, user_id, agent_id):
def test_user_message_memory(server, user_id, agent_id):
try:
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
raise Exception("user_message call should have failed")
@@ -223,3 +233,145 @@ def test_get_archival_memory(server, user_id, agent_id):
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
assert len(passage_none) == 0
def _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False):
"""Reverse is off by default, the GET goes in chronological order"""
messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
)
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
assert all(isinstance(m, Message) for m in messages)
memgpt_messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
return_message_object=False,
)
# memgpt_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
assert all(isinstance(m, MemGPTMessage) for m in memgpt_messages)
# Loop through `messages` while also looping through `memgpt_messages`
# Each message in `messages` should have 1+ corresponding messages in `memgpt_messages`
# If role of message (in `messages`) is `assistant`,
# then there should be two messages in `memgpt_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
# If role of message (in `messages`) is `user`, then there should be one message in `memgpt_messages` which is type UserMessage.
# If role of message (in `messages`) is `system`, then there should be one message in `memgpt_messages` which is type SystemMessage.
# If role of message (in `messages`) is `tool`, then there should be one message in `memgpt_messages` which is type FunctionReturn.
print("MESSAGES (obj):")
for i, m in enumerate(messages):
# print(m)
print(f"{i}: {m.role}, {m.text[:50]}...")
# print(m.role)
print("MEMGPT_MESSAGES:")
for i, m in enumerate(memgpt_messages):
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
# Collect system messages and their texts
system_messages = [m for m in messages if m.role == MessageRole.system]
system_texts = [m.text for m in system_messages]
# If there are multiple system messages, print the diff
if len(system_messages) > 1:
print("Differences between system messages:")
for i in range(len(system_texts) - 1):
for j in range(i + 1, len(system_texts)):
import difflib
diff = difflib.unified_diff(
system_texts[i].splitlines(),
system_texts[j].splitlines(),
fromfile=f"System Message {i+1}",
tofile=f"System Message {j+1}",
lineterm="",
)
print("\n".join(diff))
else:
print("There is only one or no system message.")
memgpt_message_index = 0
for i, message in enumerate(messages):
assert isinstance(message, Message)
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
while memgpt_message_index < len(memgpt_messages):
memgpt_message = memgpt_messages[memgpt_message_index]
print(f"memgpt_message {memgpt_message_index}: {str(memgpt_message)[:50]}")
if message.role == MessageRole.assistant:
print(f"i={i}, M=assistant, MM={type(memgpt_message)}")
# If reverse, function call will come first
if reverse:
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
assert isinstance(memgpt_message, FunctionCallMessage)
memgpt_message_index += 1
memgpt_message = memgpt_messages[memgpt_message_index]
if message.text is not None:
assert isinstance(memgpt_message, InternalMonologue)
memgpt_message_index += 1
memgpt_message = memgpt_messages[memgpt_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
else:
if message.text is not None:
assert isinstance(memgpt_message, InternalMonologue)
memgpt_message_index += 1
memgpt_message = memgpt_messages[memgpt_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
assert isinstance(memgpt_message, FunctionCallMessage)
assert tool_call.function.name == memgpt_message.function_call.name
assert tool_call.function.arguments == memgpt_message.function_call.arguments
memgpt_message_index += 1
memgpt_message = memgpt_messages[memgpt_message_index]
elif message.role == MessageRole.user:
print(f"i={i}, M=user, MM={type(memgpt_message)}")
assert isinstance(memgpt_message, UserMessage)
assert message.text == memgpt_message.message
memgpt_message_index += 1
elif message.role == MessageRole.system:
print(f"i={i}, M=system, MM={type(memgpt_message)}")
assert isinstance(memgpt_message, SystemMessage)
assert message.text == memgpt_message.message
memgpt_message_index += 1
elif message.role == MessageRole.tool:
print(f"i={i}, M=tool, MM={type(memgpt_message)}")
assert isinstance(memgpt_message, FunctionReturn)
# Check the the value in `text` is the same
assert message.text == memgpt_message.function_return
memgpt_message_index += 1
else:
raise ValueError(f"Unexpected message role: {message.role}")
# Move to the next message in the original messages list
break
def test_get_messages_memgpt_format(server, user_id, agent_id):
_test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False)
_test_get_messages_memgpt_format(server, user_id, agent_id, reverse=True)