feat: Get in-context Message.id values from server (#851)
This commit is contained in:
@@ -307,13 +307,20 @@ class BaseRecallMemory(RecallMemory):
|
||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||
self.cache = {}
|
||||
|
||||
def get_all(self, start=0, count=None):
|
||||
results = self.storage.get_all(start, count)
|
||||
results_json = [message.to_openai_dict() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
results = self.storage.query_text(query_string, count, start)
|
||||
return results, len(results)
|
||||
results_json = [message.to_openai_dict() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def date_search(self, start_date, end_date, count=None, start=None):
|
||||
results = self.storage.query_date(start_date, end_date, count, start)
|
||||
return results, len(results)
|
||||
results_json = [message.to_openai_dict() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
total = self.storage.size()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Union, Callable, Optional, Tuple
|
||||
from typing import Union, Callable, Optional, Tuple, List
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
@@ -35,6 +35,21 @@ from memgpt.data_types import (
|
||||
Message,
|
||||
ToolCall,
|
||||
)
|
||||
from memgpt.data_types import (
|
||||
Source,
|
||||
Passage,
|
||||
Document,
|
||||
User,
|
||||
AgentState,
|
||||
LLMConfig,
|
||||
EmbeddingConfig,
|
||||
Message,
|
||||
ToolCall,
|
||||
LLMConfig,
|
||||
EmbeddingConfig,
|
||||
Message,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
# TODO use custom interface
|
||||
from memgpt.interface import CLIInterface # for printing to terminal
|
||||
@@ -677,6 +692,12 @@ class SyncServer(LockingServer):
|
||||
|
||||
return memory_obj
|
||||
|
||||
def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]:
|
||||
"""Get the message ids of the in-context messages in the agent's memory"""
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
return [m.id for m in memgpt_agent._messages]
|
||||
|
||||
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
|
||||
@@ -114,6 +114,14 @@ def test_server():
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test in-context message ids
|
||||
in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
|
||||
assert len(in_context_ids) == len(messages_3)
|
||||
assert isinstance(in_context_ids[0], uuid.UUID)
|
||||
message_ids = [m["id"] for m in messages_3]
|
||||
for message_id in message_ids:
|
||||
assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"
|
||||
|
||||
# test archival memory cursor pagination
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
|
||||
|
||||
Reference in New Issue
Block a user