feat: Get in-context Message.id values from server (#851)

This commit is contained in:
Sarah Wooders
2024-01-18 12:42:55 -08:00
committed by GitHub
parent 61897921fd
commit 2f7ccb1807
3 changed files with 39 additions and 3 deletions

View File

@@ -307,13 +307,20 @@ class BaseRecallMemory(RecallMemory):
# TODO: have some mechanism for cleanup otherwise will lead to OOM
self.cache = {}
def get_all(self, start=0, count=None):
results = self.storage.get_all(start, count)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)
def text_search(self, query_string, count=None, start=None):
results = self.storage.query_text(query_string, count, start)
return results, len(results)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)
def date_search(self, start_date, end_date, count=None, start=None):
results = self.storage.query_date(start_date, end_date, count, start)
return results, len(results)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)
def __repr__(self) -> str:
total = self.storage.size()

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Union, Callable, Optional, Tuple
from typing import Union, Callable, Optional, Tuple, List
import uuid
import json
import logging
@@ -35,6 +35,21 @@ from memgpt.data_types import (
Message,
ToolCall,
)
from memgpt.data_types import (
Source,
Passage,
Document,
User,
AgentState,
LLMConfig,
EmbeddingConfig,
Message,
ToolCall,
LLMConfig,
EmbeddingConfig,
Message,
ToolCall,
)
# TODO use custom interface
from memgpt.interface import CLIInterface # for printing to terminal
@@ -677,6 +692,12 @@ class SyncServer(LockingServer):
return memory_obj
def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]:
"""Get the message ids of the in-context messages in the agent's memory"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
return [m.id for m in memgpt_agent._messages]
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of all messages in agent message queue"""
if self.ms.get_user(user_id=user_id) is None:

View File

@@ -114,6 +114,14 @@ def test_server():
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test in-context message ids
in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
assert len(in_context_ids) == len(messages_3)
assert isinstance(in_context_ids[0], uuid.UUID)
message_ids = [m["id"] for m in messages_3]
for message_id in message_ids:
assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"