feat: add support for returning type MemGPTMessage from cursor GET (#1723)
This commit is contained in:
@@ -642,7 +642,7 @@ class RESTClient(AbstractClient):
|
||||
messages (List[Message]): List of messages
|
||||
"""
|
||||
|
||||
params = {"before": before, "after": after, "limit": limit}
|
||||
params = {"before": before, "after": after, "limit": limit, "msg_object": True}
|
||||
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get messages: {response.text}")
|
||||
@@ -2151,7 +2151,13 @@ class LocalClient(AbstractClient):
|
||||
|
||||
self.interface.clear()
|
||||
return self.server.get_agent_recall_cursor(
|
||||
user_id=self.user_id, agent_id=agent_id, before=before, after=after, limit=limit, reverse=True
|
||||
user_id=self.user_id,
|
||||
agent_id=agent_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
reverse=True,
|
||||
return_message_object=True,
|
||||
)
|
||||
|
||||
def list_models(self) -> List[LLMConfig]:
|
||||
|
||||
@@ -29,6 +29,32 @@ class MemGPTMessage(BaseModel):
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
class SystemMessage(MemGPTMessage):
|
||||
"""
|
||||
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
message (str): The message sent by the system
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
class UserMessage(MemGPTMessage):
|
||||
"""
|
||||
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
message (str): The message sent by the user
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
class InternalMonologue(MemGPTMessage):
|
||||
"""
|
||||
Representation of an agent's internal monologue.
|
||||
|
||||
@@ -2,7 +2,7 @@ import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -10,7 +10,15 @@ from memgpt.constants import TOOL_CALL_ID_MAX_LEN
|
||||
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from memgpt.schemas.enums import MessageRole
|
||||
from memgpt.schemas.memgpt_base import MemGPTBase
|
||||
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
|
||||
from memgpt.schemas.memgpt_message import (
|
||||
FunctionCall,
|
||||
FunctionCallMessage,
|
||||
FunctionReturn,
|
||||
InternalMonologue,
|
||||
MemGPTMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from memgpt.utils import get_utc_time, is_utc_datetime, json_dumps
|
||||
|
||||
@@ -96,11 +104,90 @@ class Message(BaseMessage):
|
||||
json_message["created_at"] = self.created_at.isoformat()
|
||||
return json_message
|
||||
|
||||
def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
|
||||
def to_memgpt_message(self) -> List[MemGPTMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original MemGPT API"""
|
||||
|
||||
# NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
|
||||
raise NotImplementedError
|
||||
messages = []
|
||||
|
||||
if self.role == MessageRole.assistant:
|
||||
if self.text is not None:
|
||||
# This is type InnerThoughts
|
||||
messages.append(
|
||||
InternalMonologue(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
internal_monologue=self.text,
|
||||
)
|
||||
)
|
||||
if self.tool_calls is not None:
|
||||
# This is type FunctionCall
|
||||
for tool_call in self.tool_calls:
|
||||
messages.append(
|
||||
FunctionCallMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
function_call=FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.tool:
|
||||
# This is type FunctionReturn
|
||||
# Try to interpret the function return, recall that this is how we packaged:
|
||||
# def package_function_response(was_success, response_string, timestamp=None):
|
||||
# formatted_time = get_local_time() if timestamp is None else timestamp
|
||||
# packaged_message = {
|
||||
# "status": "OK" if was_success else "Failed",
|
||||
# "message": response_string,
|
||||
# "time": formatted_time,
|
||||
# }
|
||||
assert self.text is not None, self
|
||||
try:
|
||||
function_return = json.loads(self.text)
|
||||
status = function_return["status"]
|
||||
if status == "OK":
|
||||
status_enum = "success"
|
||||
elif status == "Failed":
|
||||
status_enum = "error"
|
||||
else:
|
||||
raise ValueError(f"Invalid status: {status}")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode function return: {self.text}")
|
||||
messages.append(
|
||||
# TODO make sure this is what the API returns
|
||||
# function_return may not match exactly...
|
||||
FunctionReturn(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
function_return=self.text,
|
||||
status=status_enum,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.user:
|
||||
# This is type UserMessage
|
||||
assert self.text is not None, self
|
||||
messages.append(
|
||||
UserMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
message=self.text,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.system:
|
||||
# This is type SystemMessage
|
||||
assert self.text is not None, self
|
||||
messages.append(
|
||||
SystemMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
message=self.text,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(self.role)
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def dict_to_message(
|
||||
|
||||
@@ -129,32 +129,26 @@ async def send_message_to_agent(
|
||||
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/{agent_id}/messages/context/", tags=["agents"], response_model=List[Message])
|
||||
def get_agent_messages_in_context(
|
||||
agent_id: str,
|
||||
start: int = Query(..., description="Message index to start on (reverse chronological)."),
|
||||
count: int = Query(..., description="How many messages to retrieve."),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
|
||||
"""
|
||||
interface.clear()
|
||||
messages = server.get_agent_messages(agent_id=agent_id, start=start, count=count)
|
||||
return messages
|
||||
|
||||
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
|
||||
def get_agent_messages(
|
||||
agent_id: str,
|
||||
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
msg_object: bool = Query(False, description="If true, returns Message objects. If false, return MemGPTMessage objects."),
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
interface.clear()
|
||||
return server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, before=before, limit=limit, reverse=True)
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
before=before,
|
||||
limit=limit,
|
||||
reverse=True,
|
||||
return_message_object=msg_object,
|
||||
)
|
||||
|
||||
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
|
||||
async def send_message(
|
||||
|
||||
@@ -54,6 +54,7 @@ from memgpt.schemas.embedding_config import EmbeddingConfig
|
||||
from memgpt.schemas.enums import JobStatus
|
||||
from memgpt.schemas.job import Job
|
||||
from memgpt.schemas.llm_config import LLMConfig
|
||||
from memgpt.schemas.memgpt_message import MemGPTMessage
|
||||
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
|
||||
@@ -990,7 +991,13 @@ class SyncServer(Server):
|
||||
message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=message_id)
|
||||
return message
|
||||
|
||||
def get_agent_messages(self, agent_id: str, start: int, count: int) -> List[Message]:
|
||||
def get_agent_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
start: int,
|
||||
count: int,
|
||||
return_message_object: bool = True,
|
||||
) -> Union[List[Message], List[MemGPTMessage]]:
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
@@ -1025,6 +1032,7 @@ class SyncServer(Server):
|
||||
|
||||
# return messages in reverse chronological order
|
||||
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
## Convert to json
|
||||
## Add a tag indicating in-context or not
|
||||
@@ -1033,6 +1041,9 @@ class SyncServer(Server):
|
||||
# for d in json_messages:
|
||||
# d["in_context"] = True if str(d["id"]) in in_context_message_ids else False
|
||||
|
||||
if not return_message_object:
|
||||
messages = [msg for m in messages for msg in m.to_memgpt_message()]
|
||||
|
||||
return messages
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
|
||||
@@ -1118,7 +1129,8 @@ class SyncServer(Server):
|
||||
order_by: Optional[str] = "created_at",
|
||||
order: Optional[str] = "asc",
|
||||
reverse: Optional[bool] = False,
|
||||
) -> List[Message]:
|
||||
return_message_object: bool = True,
|
||||
) -> Union[List[Message], List[MemGPTMessage]]:
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
@@ -1131,6 +1143,16 @@ class SyncServer(Server):
|
||||
cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor(
|
||||
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
|
||||
)
|
||||
|
||||
assert all(isinstance(m, Message) for m in records)
|
||||
|
||||
if not return_message_object:
|
||||
# If we're GETing messages in reverse, we need to reverse the inner list (generated by to_memgpt_message)
|
||||
if reverse:
|
||||
records = [msg for m in records for msg in m.to_memgpt_message()[::-1]]
|
||||
else:
|
||||
records = [msg for m in records for msg in m.to_memgpt_message()]
|
||||
|
||||
return records
|
||||
|
||||
def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]:
|
||||
|
||||
@@ -4,11 +4,21 @@ import pytest
|
||||
|
||||
import memgpt.utils as utils
|
||||
from memgpt.constants import BASE_TOOLS
|
||||
from memgpt.schemas.enums import MessageRole
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.schemas.agent import CreateAgent
|
||||
from memgpt.schemas.memgpt_message import (
|
||||
FunctionCallMessage,
|
||||
FunctionReturn,
|
||||
InternalMonologue,
|
||||
MemGPTMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from memgpt.schemas.memory import ChatMemory
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.source import SourceCreate
|
||||
from memgpt.schemas.user import UserCreate
|
||||
from memgpt.server.server import SyncServer
|
||||
@@ -83,7 +93,7 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(1)
|
||||
def test_user_message(server, user_id, agent_id):
|
||||
def test_user_message_memory(server, user_id, agent_id):
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
|
||||
raise Exception("user_message call should have failed")
|
||||
@@ -223,3 +233,145 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
def _test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False):
|
||||
"""Reverse is off by default, the GET goes in chronological order"""
|
||||
|
||||
messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
limit=1000,
|
||||
reverse=reverse,
|
||||
)
|
||||
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
memgpt_messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
limit=1000,
|
||||
reverse=reverse,
|
||||
return_message_object=False,
|
||||
)
|
||||
# memgpt_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
|
||||
assert all(isinstance(m, MemGPTMessage) for m in memgpt_messages)
|
||||
|
||||
# Loop through `messages` while also looping through `memgpt_messages`
|
||||
# Each message in `messages` should have 1+ corresponding messages in `memgpt_messages`
|
||||
# If role of message (in `messages`) is `assistant`,
|
||||
# then there should be two messages in `memgpt_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
|
||||
# If role of message (in `messages`) is `user`, then there should be one message in `memgpt_messages` which is type UserMessage.
|
||||
# If role of message (in `messages`) is `system`, then there should be one message in `memgpt_messages` which is type SystemMessage.
|
||||
# If role of message (in `messages`) is `tool`, then there should be one message in `memgpt_messages` which is type FunctionReturn.
|
||||
|
||||
print("MESSAGES (obj):")
|
||||
for i, m in enumerate(messages):
|
||||
# print(m)
|
||||
print(f"{i}: {m.role}, {m.text[:50]}...")
|
||||
# print(m.role)
|
||||
|
||||
print("MEMGPT_MESSAGES:")
|
||||
for i, m in enumerate(memgpt_messages):
|
||||
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
|
||||
|
||||
# Collect system messages and their texts
|
||||
system_messages = [m for m in messages if m.role == MessageRole.system]
|
||||
system_texts = [m.text for m in system_messages]
|
||||
|
||||
# If there are multiple system messages, print the diff
|
||||
if len(system_messages) > 1:
|
||||
print("Differences between system messages:")
|
||||
for i in range(len(system_texts) - 1):
|
||||
for j in range(i + 1, len(system_texts)):
|
||||
import difflib
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
system_texts[i].splitlines(),
|
||||
system_texts[j].splitlines(),
|
||||
fromfile=f"System Message {i+1}",
|
||||
tofile=f"System Message {j+1}",
|
||||
lineterm="",
|
||||
)
|
||||
print("\n".join(diff))
|
||||
else:
|
||||
print("There is only one or no system message.")
|
||||
|
||||
memgpt_message_index = 0
|
||||
for i, message in enumerate(messages):
|
||||
assert isinstance(message, Message)
|
||||
|
||||
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
|
||||
while memgpt_message_index < len(memgpt_messages):
|
||||
memgpt_message = memgpt_messages[memgpt_message_index]
|
||||
print(f"memgpt_message {memgpt_message_index}: {str(memgpt_message)[:50]}")
|
||||
|
||||
if message.role == MessageRole.assistant:
|
||||
print(f"i={i}, M=assistant, MM={type(memgpt_message)}")
|
||||
|
||||
# If reverse, function call will come first
|
||||
if reverse:
|
||||
|
||||
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
|
||||
if message.tool_calls is not None:
|
||||
for tool_call in message.tool_calls:
|
||||
assert isinstance(memgpt_message, FunctionCallMessage)
|
||||
memgpt_message_index += 1
|
||||
memgpt_message = memgpt_messages[memgpt_message_index]
|
||||
|
||||
if message.text is not None:
|
||||
assert isinstance(memgpt_message, InternalMonologue)
|
||||
memgpt_message_index += 1
|
||||
memgpt_message = memgpt_messages[memgpt_message_index]
|
||||
else:
|
||||
# If there's no inner thoughts then there needs to be a tool call
|
||||
assert message.tool_calls is not None
|
||||
|
||||
else:
|
||||
|
||||
if message.text is not None:
|
||||
assert isinstance(memgpt_message, InternalMonologue)
|
||||
memgpt_message_index += 1
|
||||
memgpt_message = memgpt_messages[memgpt_message_index]
|
||||
else:
|
||||
# If there's no inner thoughts then there needs to be a tool call
|
||||
assert message.tool_calls is not None
|
||||
|
||||
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
|
||||
if message.tool_calls is not None:
|
||||
for tool_call in message.tool_calls:
|
||||
assert isinstance(memgpt_message, FunctionCallMessage)
|
||||
assert tool_call.function.name == memgpt_message.function_call.name
|
||||
assert tool_call.function.arguments == memgpt_message.function_call.arguments
|
||||
memgpt_message_index += 1
|
||||
memgpt_message = memgpt_messages[memgpt_message_index]
|
||||
|
||||
elif message.role == MessageRole.user:
|
||||
print(f"i={i}, M=user, MM={type(memgpt_message)}")
|
||||
assert isinstance(memgpt_message, UserMessage)
|
||||
assert message.text == memgpt_message.message
|
||||
memgpt_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.system:
|
||||
print(f"i={i}, M=system, MM={type(memgpt_message)}")
|
||||
assert isinstance(memgpt_message, SystemMessage)
|
||||
assert message.text == memgpt_message.message
|
||||
memgpt_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.tool:
|
||||
print(f"i={i}, M=tool, MM={type(memgpt_message)}")
|
||||
assert isinstance(memgpt_message, FunctionReturn)
|
||||
# Check the the value in `text` is the same
|
||||
assert message.text == memgpt_message.function_return
|
||||
memgpt_message_index += 1
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected message role: {message.role}")
|
||||
|
||||
# Move to the next message in the original messages list
|
||||
break
|
||||
|
||||
|
||||
def test_get_messages_memgpt_format(server, user_id, agent_id):
|
||||
_test_get_messages_memgpt_format(server, user_id, agent_id, reverse=False)
|
||||
_test_get_messages_memgpt_format(server, user_id, agent_id, reverse=True)
|
||||
|
||||
Reference in New Issue
Block a user