diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index 6253b82d..a297531b 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -337,8 +337,8 @@ class SQLStorageConnector(StorageConnector): db_records = self.session.query(self.db_model).filter(*filters).all() return [record.to_record() for record in db_records] - def get(self, id: str) -> Optional[Record]: - db_record = self.session.query(self.db_model).get(id) + def get(self, id: uuid.UUID) -> Optional[Record]: + db_record = self.session.get(self.db_model, id) if db_record is None: return None return db_record.to_record() diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 07f039f0..54d08585 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -386,7 +386,7 @@ class AgentState: self.created_at = created_at if created_at is not None else datetime.now() # state - self.state = state + self.state = {} if not state else state # def __eq__(self, other): # if not isinstance(other, AgentState): diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 4f703991..64b42cf9 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -20,7 +20,7 @@ import memgpt.presets.presets as presets import memgpt.utils as utils import memgpt.server.utils as server_utils from memgpt.persistence_manager import PersistenceManager, LocalStateManager -from memgpt.data_types import Source, Passage, Document, User, AgentState +from memgpt.data_types import Source, Passage, Document, User, AgentState, LLMConfig, EmbeddingConfig, Message, ToolCall # TODO use custom interface from memgpt.interface import CLIInterface # for printing to terminal @@ -167,13 +167,47 @@ class SyncServer(LockingServer): # Initialize the connection to the DB self.config = MemGPTConfig.load() + + # Ensure valid database configuration + # TODO: add back once tests are matched + # assert ( + # self.config.metadata_storage_type == "postgres" + # ), f"Invalid metadata_storage_type for server: {self.config.metadata_storage_type}" + # assert ( + # self.config.archival_storage_type == "postgres" + # ), f"Invalid archival_storage_type for server: {self.config.archival_storage_type}" + # assert self.config.recall_storage_type == "postgres", f"Invalid recall_storage_type for server: {self.config.recall_storage_type}" + + # Generate default LLM/Embedding configs for the server + # TODO: we may also want to do the same thing with default persona/human/etc. + self.server_llm_config = LLMConfig( + model=self.config.model, + model_endpoint_type=self.config.model_endpoint_type, + model_endpoint=self.config.model_endpoint, + model_wrapper=self.config.model_wrapper, + context_window=self.config.context_window, + openai_key=self.config.openai_key, + azure_key=self.config.azure_key, + azure_endpoint=self.config.azure_endpoint, + azure_version=self.config.azure_version, + azure_deployment=self.config.azure_deployment, + ) + self.server_embedding_config = EmbeddingConfig( + embedding_endpoint_type=self.config.embedding_endpoint_type, + embedding_endpoint=self.config.embedding_endpoint, + embedding_dim=self.config.embedding_dim, + openai_key=self.config.openai_key, + ) + + # Initialize the metadata store self.ms = MetadataStore(self.config) - # Create the default user - base_user_id = uuid.UUID(self.config.anon_clientid) - if not self.ms.get_user(user_id=base_user_id): - base_user = User(id=base_user_id) - self.ms.create_user(base_user) + # NOTE: removed, since server should be multi-user + ## Create the default user + # base_user_id = uuid.UUID(self.config.anon_clientid) + # if not self.ms.get_user(user_id=base_user_id): + # base_user = User(id=base_user_id) + # self.ms.create_user(base_user) def save_agents(self): """Saves all the agents that are in the in-memory object store""" @@ -421,7 +455,6 @@ class SyncServer(LockingServer): @LockingServer.agent_lock_decorator def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: """Process an incoming user message and feed it through the MemGPT agent""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -443,7 +476,6 @@ class SyncServer(LockingServer): @LockingServer.agent_lock_decorator def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None: """Process an incoming system message and feed it through the MemGPT agent""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -465,7 +497,6 @@ class SyncServer(LockingServer): @LockingServer.agent_lock_decorator def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]: """Run a command on the agent""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -475,6 +506,26 @@ class SyncServer(LockingServer): command = command[1:] # strip the prefix return self._command(user_id=user_id, agent_id=agent_id, command=command) + def create_user( + self, + user_config: Optional[Union[dict, User]] = {}, + ): + """Create a new user using a config""" + if not isinstance(user_config, dict): + raise ValueError(f"user_config must be provided as a dictionary") + + user = User( + id=user_config["id"] if "id" in user_config else None, + default_preset=user_config["default_preset"] if "default_preset" in user_config else "memgpt_chat", + default_persona=user_config["default_persona"] if "default_persona" in user_config else constants.DEFAULT_PERSONA, + default_human=user_config["default_human"] if "default_human" in user_config else constants.DEFAULT_HUMAN, + default_llm_config=self.server_llm_config, + default_embedding_config=self.server_embedding_config, + ) + self.ms.create_user(user) + logger.info(f"Created new user from config: {user}") + return user + def create_agent( self, user_id: uuid.UUID, @@ -483,7 +534,6 @@ class SyncServer(LockingServer): # persistence_manager: Union[PersistenceManager, None] = None, ) -> AgentState: """Create a new agent using a config""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -519,6 +569,9 @@ class SyncServer(LockingServer): logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}") try: agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface) + + # TODO: this is a hacky way to get the system prompts injected into agent into the DB + self.ms.update_agent(agent.agent_state) except Exception as e: logger.exception(e) self.ms.delete_agent(agent_id=agent_state.id) @@ -533,25 +586,19 @@ class SyncServer(LockingServer): user_id: uuid.UUID, agent_id: uuid.UUID, ): - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") - # Make sure the user owns the agent - # TODO use real user_id - USER_ID = self.config.anon_clientid - agent = self.ms.get_agent(agent_id=agent_id, user_id=USER_ID) + # TODO: Make sure the user owns the agent + agent = self.ms.get_agent(agent_id=agent_id, user_id=user_id) if agent is not None: self.ms.delete_agent(agent_id=agent_id) def list_agents(self, user_id: uuid.UUID) -> dict: """List all available agents to a user""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") - # TODO actually use the user_id that was passed into the server - user_id = uuid.UUID(self.config.anon_clientid) agents_states = self.ms.list_agents(user_id=user_id) logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}") return { @@ -568,9 +615,16 @@ class SyncServer(LockingServer): ], } + def get_agent(self, agent_id: uuid.UUID): + """Get the agent state""" + return self.ms.get_agent(agent_id=agent_id) + + def get_user(self, user_id: uuid.UUID) -> User: + """Get the user""" + return self.ms.get_user(user_id=user_id) + def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the memory of an agent (core memory + non-core statistics)""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -594,7 +648,6 @@ class SyncServer(LockingServer): def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of all messages in agent message queue""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -636,7 +689,6 @@ class SyncServer(LockingServer): def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: """Paginated query of all messages in agent archival memory""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -661,7 +713,6 @@ class SyncServer(LockingServer): order_by: Optional[str] = "created_at", reverse: Optional[bool] = False, ): - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -685,7 +736,6 @@ class SyncServer(LockingServer): order_by: Optional[str] = "created_at", reverse: Optional[bool] = False, ): - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -703,7 +753,6 @@ class SyncServer(LockingServer): def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: """Return the config of an agent""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -730,7 +779,6 @@ class SyncServer(LockingServer): def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict: """Update the agents core memory block, return the new state""" - user_id = uuid.UUID(self.config.anon_clientid) # TODO use real if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py index d590cbb9..c50ee494 100644 --- a/tests/test_agent_function_update.py +++ b/tests/test_agent_function_update.py @@ -34,6 +34,13 @@ def agent(): else: client = MemGPT(quickstart="memgpt_hosted") + config = MemGPTConfig.load() + + # ensure user exists + user_id = uuid.UUID(config.anon_clientid) + if not client.server.get_user(user_id=user_id): + client.server.create_user({"id": user_id}) + agent_state = client.create_agent( agent_config={ # "name": test_agent_id, @@ -42,9 +49,6 @@ def agent(): } ) - config = MemGPTConfig.load() - user_id = uuid.UUID(config.anon_clientid) - return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) diff --git a/tests/test_client.py b/tests/test_client.py index 9109338b..f7b369ce 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -23,6 +23,13 @@ def test_create_agent(): else: client = MemGPT(quickstart="memgpt_hosted") + config = MemGPTConfig.load() + + # ensure user exists + user_id = uuid.UUID(config.anon_clientid) + if not client.server.get_user(user_id=user_id): + client.server.create_user({"id": user_id}) + global test_agent_state test_agent_state = client.create_agent( agent_config={ diff --git a/tests/test_server.py b/tests/test_server.py index ff7678a6..6b2a5fef 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,7 +5,7 @@ import memgpt.utils as utils utils.DEBUG = True from memgpt.config import MemGPTConfig from memgpt.server.server import SyncServer -from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage +from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User from memgpt.embeddings import embedding_model from memgpt.metadata import MetadataStore from .utils import wipe_config, wipe_memgpt_home @@ -14,22 +14,34 @@ from .utils import wipe_config, wipe_memgpt_home def test_server(): wipe_memgpt_home() - config = MemGPTConfig.load() - - # setup config for postgres storage - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config.archival_storage_type = "postgres" - config.recall_storage_type = "postgres" + config = MemGPTConfig( + archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + openai_key=os.getenv("OPENAI_API_KEY"), + # llms + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ) config.save() - user_id = uuid.UUID(config.anon_clientid) - ms = MetadataStore(config) server = SyncServer() + # create user + user = server.create_user() + print(f"Created user\n{user.id}") + try: fake_agent_id = uuid.uuid4() - server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?") + server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?") raise Exception("user_message call should have failed") except (KeyError, ValueError) as e: # Error is expected @@ -37,29 +49,14 @@ def test_server(): except: raise - # embedding config - if os.getenv("OPENAI_API_KEY"): - embedding_config = EmbeddingConfig( - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_dim=1536, - openai_key=os.getenv("OPENAI_API_KEY"), - ) - print("Using OpenAI embeddings") - else: - embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384) - print("Using local embeddings") - agent_state = server.create_agent( - user_id=user_id, - agent_config=dict( - name="test_agent", user_id=user_id, preset="memgpt_chat", human="cs_phd", persona="sam_pov", embedding_config=embedding_config - ), + user_id=user.id, + agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"), ) print(f"Created agent\n{agent_state}") try: - server.user_message(user_id=user_id, agent_id=agent_state.id, message="/memory") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory") raise Exception("user_message call should have failed") except ValueError as e: # Error is expected @@ -67,47 +64,47 @@ def test_server(): except: raise - print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory")) + print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory")) # add data into archival memory - agent = server._load_agent(user_id=user_id, agent_id=agent_state.id) + agent = server._load_agent(user_id=user.id, agent_id=agent_state.id) archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"] - embed_model = embedding_model(embedding_config) + embed_model = embedding_model(agent.agent_state.embedding_config) for text in archival_memories: embedding = embed_model.get_text_embedding(text) agent.persistence_manager.archival_memory.storage.insert( - Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding) + Passage(user_id=user.id, agent_id=agent_state.id, text=text, embedding=embedding) ) # add data into recall memory - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?") # test recall memory cursor pagination - cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2) - cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000) - cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000) + cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=2) + cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000) + cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=1000) ids3 = [m["id"] for m in messages_3] ids2 = [m["id"] for m in messages_2] timestamps = [m["created_at"] for m in messages_3] print("timestamps", timestamps) assert messages_3[-1]["created_at"] < messages_3[0]["created_at"] assert len(messages_3) == len(messages_1) + len(messages_2) - cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1) + cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1) assert len(messages_4) == 1 # test archival memory cursor pagination cursor1, passages_1 = server.get_agent_archival_cursor( - user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text" + user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text" ) cursor2, passages_2 = server.get_agent_archival_cursor( - user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text" + user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text" ) cursor3, passages_3 = server.get_agent_archival_cursor( - user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text" + user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text" ) print("p1", [p["text"] for p in passages_1]) print("p2", [p["text"] for p in passages_2]) @@ -117,23 +114,23 @@ def test_server(): assert len(passages_3) == 4 # test recall memory - messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1) + messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1) assert len(messages_1) == 1 - messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000) - messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5) + messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000) + messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5) # not sure exactly how many messages there should be assert len(messages_2) > len(messages_3) # test safe empty return - messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000) + messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000) assert len(messages_none) == 0 # test archival memory - passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1) + passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1) assert len(passage_1) == 1 - passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000) + passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000) assert len(passage_2) == 4 # test safe empty return - passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000) + passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000) assert len(passage_none) == 0