diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index 8b613fc6..7b67f4ec 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -4,7 +4,8 @@ import psycopg from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime -from sqlalchemy import func +from sqlalchemy import func, or_, and_ +from sqlalchemy import desc, asc from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base from sqlalchemy.orm.session import close_all_sessions from sqlalchemy.sql import func @@ -15,7 +16,7 @@ import uuid import re from tqdm import tqdm -from typing import Optional, List, Iterator, Dict +from typing import Optional, List, Iterator, Dict, Tuple import numpy as np from tqdm import tqdm import pandas as pd @@ -67,16 +68,13 @@ class CommonVector(TypeDecorator): if value: assert isinstance(value, np.ndarray) or isinstance(value, list), f"Value must be of type np.ndarray or list, got {type(value)}" assert isinstance(value[0], float), f"Value must be of type float, got {type(value[0])}" - # print("WRITE", np.array(value).tobytes()) return np.array(value).tobytes() else: - # print("WRITE", value, type(value)) return value def process_result_value(self, value, dialect): if not value: return value - # print("dialect", dialect, type(value)) return np.frombuffer(value) @@ -125,6 +123,10 @@ def get_db_model( raise ValueError(f"User {user_id} not found") embedding_dim = user.default_embedding_config.embedding_dim + # this cannot be the case if we are making an agent-specific table + assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found" + assert table_type != TableType.ARCHIVAL_MEMORY, f"Agent {agent_id} not found" + # Define a helper function to create or get the model class def create_or_get_model(class_name, base_model, table_name): if class_name in globals(): @@ -276,6 +278,57 @@ class SQLStorageConnector(StorageConnector): # Increment the offset to get the next chunk in the next iteration offset += page_size + def get_all_cursor( + self, + filters: Optional[Dict] = {}, + after: uuid.UUID = None, + before: uuid.UUID = None, + limit: Optional[int] = 1000, + order_by: str = "created_at", + reverse: bool = False, + ): + """Get all that returns a cursor (record.id) and records""" + filters = self.get_filters(filters) + + # generate query + query = self.session.query(self.db_model).filter(*filters) + # query = query.order_by(asc(self.db_model.id)) + + # records are sorted by the order_by field first, and then by the ID if two fields are the same + if reverse: + query = query.order_by(desc(getattr(self.db_model, order_by)), asc(self.db_model.id)) + else: + query = query.order_by(asc(getattr(self.db_model, order_by)), asc(self.db_model.id)) + + # cursor logic: filter records based on before/after ID + if after: + after_value = getattr(self.get(id=after), order_by) + if reverse: # if reverse, then we want to get records that are less than the after_value + sort_exp = getattr(self.db_model, order_by) < after_value + else: # otherwise, we want to get records that are greater than the after_value + sort_exp = getattr(self.db_model, order_by) > after_value + query = query.filter( + or_(sort_exp, and_(getattr(self.db_model, order_by) == after_value, self.db_model.id > after)) # tiebreaker case + ) + if before: + before_value = getattr(self.get(id=before), order_by) + if reverse: + sort_exp = getattr(self.db_model, order_by) > before_value + else: + sort_exp = getattr(self.db_model, order_by) < before_value + query = query.filter(or_(sort_exp, and_(getattr(self.db_model, order_by) == before_value, self.db_model.id < before))) + + # get records + db_record_chunk = query.limit(limit).all() + if not db_record_chunk: + return None + records = [record.to_record() for record in db_record_chunk] + next_cursor = db_record_chunk[-1].id + assert isinstance(next_cursor, uuid.UUID) + + # return (cursor, list[records]) + return (next_cursor, records) + def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]: filters = self.get_filters(filters) if limit: diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 0ec46396..4f703991 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Union, Callable +from typing import Union, Callable, Optional, Tuple import uuid import json import logging @@ -513,11 +513,15 @@ class SyncServer(LockingServer): llm_config=agent_config["llm_config"] if "llm_config" in agent_config else user.default_llm_config, embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else user.default_embedding_config, ) + # NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation + self.ms.create_agent(agent_state) + logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}") try: agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface) except Exception as e: logger.exception(e) + self.ms.delete_agent(agent_id=agent_state.id) raise logger.info(f"Created new agent from config: {agent}") @@ -647,6 +651,56 @@ class SyncServer(LockingServer): json_passages = [vars(record) for record in page] return json_passages + def get_agent_archival_cursor( + self, + user_id: uuid.UUID, + agent_id: uuid.UUID, + after: Optional[uuid.UUID] = None, + before: Optional[uuid.UUID] = None, + limit: Optional[int] = 100, + order_by: Optional[str] = "created_at", + reverse: Optional[bool] = False, + ): + user_id = uuid.UUID(self.config.anon_clientid) # TODO use real + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + + # iterate over recorde + cursor, records = memgpt_agent.persistence_manager.archival_memory.storage.get_all_cursor( + after=after, before=before, limit=limit, order_by=order_by, reverse=reverse + ) + json_records = [vars(record) for record in records] + return cursor, json_records + + def get_agent_recall_cursor( + self, + user_id: uuid.UUID, + agent_id: uuid.UUID, + after: Optional[uuid.UUID] = None, + before: Optional[uuid.UUID] = None, + limit: Optional[int] = 100, + order_by: Optional[str] = "created_at", + reverse: Optional[bool] = False, + ): + user_id = uuid.UUID(self.config.anon_clientid) # TODO use real + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) + + # iterate over records + cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor( + after=after, before=before, limit=limit, order_by=order_by, reverse=reverse + ) + json_records = [vars(record) for record in records] + + # TODO: mark what is in-context versus not + return cursor, json_records + def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the config of an agent""" user_id = uuid.UUID(self.config.anon_clientid) # TODO use real diff --git a/tests/test_server.py b/tests/test_server.py index 59cb9513..ff7678a6 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,5 @@ import uuid import os - import memgpt.utils as utils utils.DEBUG = True @@ -8,6 +7,7 @@ from memgpt.config import MemGPTConfig from memgpt.server.server import SyncServer from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage from memgpt.embeddings import embedding_model +from memgpt.metadata import MetadataStore from .utils import wipe_config, wipe_memgpt_home @@ -24,6 +24,7 @@ def test_server(): config.save() user_id = uuid.UUID(config.anon_clientid) + ms = MetadataStore(config) server = SyncServer() try: @@ -44,9 +45,10 @@ def test_server(): embedding_dim=1536, openai_key=os.getenv("OPENAI_API_KEY"), ) - + print("Using OpenAI embeddings") else: embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384) + print("Using local embeddings") agent_state = server.create_agent( user_id=user_id, @@ -67,41 +69,69 @@ def test_server(): print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory")) - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - - # test recall memory - messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1) - assert len(messages_1) == 1 - - messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000) - messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5) - # not sure exactly how many messages there should be - assert len(messages_2) > len(messages_3) - - # test safe empty return - messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000) - assert len(messages_none) == 0 - - # test archival memory + # add data into archival memory agent = server._load_agent(user_id=user_id, agent_id=agent_state.id) - archival_memories = ["Cinderella wore a blue dress", "Dog eat dog", "Shishir loves indian food"] + archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"] embed_model = embedding_model(embedding_config) for text in archival_memories: embedding = embed_model.get_text_embedding(text) agent.persistence_manager.archival_memory.storage.insert( Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding) ) + + # add data into recall memory + server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") + + # test recall memory cursor pagination + cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2) + cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000) + cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000) + ids3 = [m["id"] for m in messages_3] + ids2 = [m["id"] for m in messages_2] + timestamps = [m["created_at"] for m in messages_3] + print("timestamps", timestamps) + assert messages_3[-1]["created_at"] < messages_3[0]["created_at"] + assert len(messages_3) == len(messages_1) + len(messages_2) + cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1) + assert len(messages_4) == 1 + + # test archival memory cursor pagination + cursor1, passages_1 = server.get_agent_archival_cursor( + user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text" + ) + cursor2, passages_2 = server.get_agent_archival_cursor( + user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text" + ) + cursor3, passages_3 = server.get_agent_archival_cursor( + user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text" + ) + print("p1", [p["text"] for p in passages_1]) + print("p2", [p["text"] for p in passages_2]) + print("p3", [p["text"] for p in passages_3]) + assert passages_1[0]["text"] == "alpha" + assert len(passages_2) == 3 + assert len(passages_3) == 4 + + # test recall memory + messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1) + assert len(messages_1) == 1 + messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000) + messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5) + # not sure exactly how many messages there should be + assert len(messages_2) > len(messages_3) + # test safe empty return + messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000) + assert len(messages_none) == 0 + + # test archival memory passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1) assert len(passage_1) == 1 passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000) - assert len(passage_2) == 2 - - print(passage_1) - + assert len(passage_2) == 4 # test safe empty return passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000) assert len(passage_none) == 0