feat: Add paginated memory queries (#825)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
@@ -589,7 +589,7 @@ class SyncServer(LockingServer):
|
||||
return memory_obj
|
||||
|
||||
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
"""Paginated query of in-context messages in agent message queue"""
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
@@ -600,20 +600,52 @@ class SyncServer(LockingServer):
|
||||
if start < 0 or count < 0:
|
||||
raise ValueError("Start and count values should be non-negative")
|
||||
|
||||
# Reverse the list to make it in reverse chronological order
|
||||
reversed_messages = memgpt_agent.messages[::-1]
|
||||
if start + count < len(memgpt_agent.messages): # messages can be returned from whats in memory
|
||||
# Reverse the list to make it in reverse chronological order
|
||||
reversed_messages = memgpt_agent.messages[::-1]
|
||||
# Check if start is within the range of the list
|
||||
if start >= len(reversed_messages):
|
||||
raise IndexError("Start index is out of range")
|
||||
|
||||
# Check if start is within the range of the list
|
||||
if start >= len(reversed_messages):
|
||||
raise IndexError("Start index is out of range")
|
||||
# Calculate the end index, ensuring it does not exceed the list length
|
||||
end_index = min(start + count, len(reversed_messages))
|
||||
|
||||
# Calculate the end index, ensuring it does not exceed the list length
|
||||
end_index = min(start + count, len(reversed_messages))
|
||||
# Slice the list for pagination
|
||||
paginated_messages = reversed_messages[start:end_index]
|
||||
|
||||
# Slice the list for pagination
|
||||
paginated_messages = reversed_messages[start:end_index]
|
||||
# convert to message objects:
|
||||
messages = [memgpt_agent.persistence_manager.json_to_message(m) for m in paginated_messages]
|
||||
else:
|
||||
# need to access persistence manager for additional messages
|
||||
db_iterator = memgpt_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
|
||||
return paginated_messages
|
||||
# get a single page of messages
|
||||
# TODO: handle stop iteration
|
||||
page = next(db_iterator, [])
|
||||
|
||||
# return messages in reverse chronological order
|
||||
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
|
||||
|
||||
# convert to json
|
||||
json_messages = [vars(record) for record in messages]
|
||||
return json_messages
|
||||
|
||||
def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
"""Paginated query of all messages in agent archival memory"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
# iterate over records
|
||||
db_iterator = memgpt_agent.persistence_manager.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
|
||||
# get a single page of messages
|
||||
page = next(db_iterator, [])
|
||||
json_passages = [vars(record) for record in page]
|
||||
return json_passages
|
||||
|
||||
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
|
||||
"""Return the config of an agent"""
|
||||
|
||||
Reference in New Issue
Block a user