diff --git a/memgpt/memory.py b/memgpt/memory.py index 0cc1a9e4..f8d42649 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -307,13 +307,20 @@ class BaseRecallMemory(RecallMemory): # TODO: have some mechanism for cleanup otherwise will lead to OOM self.cache = {} + def get_all(self, start=0, count=None): + results = self.storage.get_all(start, count) + results_json = [message.to_openai_dict() for message in results] + return results_json, len(results) + def text_search(self, query_string, count=None, start=None): results = self.storage.query_text(query_string, count, start) - return results, len(results) + results_json = [message.to_openai_dict() for message in results] + return results_json, len(results) def date_search(self, start_date, end_date, count=None, start=None): results = self.storage.query_date(start_date, end_date, count, start) - return results, len(results) + results_json = [message.to_openai_dict() for message in results] + return results_json, len(results) def __repr__(self) -> str: total = self.storage.size() diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 8fa4cdb6..f605b668 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Union, Callable, Optional, Tuple +from typing import Union, Callable, Optional, Tuple, List import uuid import json import logging @@ -35,6 +35,21 @@ from memgpt.data_types import ( Message, ToolCall, ) +from memgpt.data_types import ( + Source, + Passage, + Document, + User, + AgentState, + LLMConfig, + EmbeddingConfig, + Message, + ToolCall, + LLMConfig, + EmbeddingConfig, + Message, + ToolCall, +) # TODO use custom interface from memgpt.interface import CLIInterface # for printing to terminal @@ -677,6 +692,12 @@ class SyncServer(LockingServer): return memory_obj + def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]: + """Get the message ids of the in-context messages in the agent's memory""" + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + return [m.id for m in memgpt_agent._messages] + def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of all messages in agent message queue""" if self.ms.get_user(user_id=user_id) is None: diff --git a/tests/test_server.py b/tests/test_server.py index f57470cf..12652899 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -114,6 +114,14 @@ def test_server(): cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1) assert len(messages_4) == 1 + # test in-context message ids + in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id) + assert len(in_context_ids) == len(messages_3) + assert isinstance(in_context_ids[0], uuid.UUID) + message_ids = [m["id"] for m in messages_3] + for message_id in message_ids: + assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}" + # test archival memory cursor pagination cursor1, passages_1 = server.get_agent_archival_cursor( user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"