diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d64e3b0b..bceba7e2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,4 +62,10 @@ jobs: PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - PGVECTOR_TEST_DB_URL=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} poetry run pytest -s -vv tests + PGVECTOR_TEST_DB_URL=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} poetry run pytest -s -vv -k "not test_storage" tests + - name: Run storage tests + env: + PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + PGVECTOR_TEST_DB_URL=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} poetry run pytest -s -vv tests/test_storage.py diff --git a/memgpt/agent.py b/memgpt/agent.py index 607d15fb..21403828 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -5,12 +5,24 @@ import os import json import traceback -from memgpt.persistence_manager import LocalStateManager -from memgpt.config import AgentConfig, MemGPTConfig +from memgpt.data_types import AgentState +from memgpt.metadata import MetadataStore +from memgpt.interface import AgentInterface +from memgpt.persistence_manager import PersistenceManager, LocalStateManager +from memgpt.config import MemGPTConfig from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from memgpt.memory import CoreMemory as InContextMemory, summarize_messages from memgpt.openai_tools import create, is_context_overflow_error -from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response +from memgpt.utils import ( + get_local_time, + parse_json, 
+ united_diff, + printd, + count_tokens, + get_schema_diff, + validate_function_response, + verify_first_message_correctness, +) from memgpt.constants import ( FIRST_MESSAGE_ATTEMPTS, MESSAGE_SUMMARY_WARNING_FRAC, @@ -146,45 +158,50 @@ def initialize_message_sequence( class Agent(object): def __init__( self, - config, - model, - system, - functions, # list of [{'schema': 'x', 'python_function': function_pointer}, ...] - interface, - persistence_manager, - persona_notes, - human_notes, - messages_total=None, - persistence_manager_init=True, - first_message_verify_mono=True, + agent_state: AgentState, + interface: AgentInterface, + # extras + messages_total=None, # TODO remove? + first_message_verify_mono=True, # TODO move to config? ): - # agent config - self.config = config + # Hold a copy of the state that was used to init the agent + self.config = agent_state # TODO: remove + self.agent_state = agent_state + + # gpt-4, gpt-3.5-turbo, ... + self.model = agent_state.llm_config.model - # gpt-4, gpt-3.5-turbo - self.model = model # Store the system instructions (used to rebuild memory) - self.system = system + if "system" not in agent_state.state: + raise ValueError(f"'system' not found in provided AgentState") + self.system = agent_state.state["system"] - # Available functions is a mapping from: - # function_name -> { - # json_schema: schema - # python_function: function - # } + if "functions" not in agent_state.state: + raise ValueError(f"'functions' not found in provided AgentState") # Store the functions schemas (this is passed as an argument to ChatCompletion) - functions_schema = [f_dict["json_schema"] for f_name, f_dict in functions.items()] - self.functions = functions_schema - # Store references to the python objects - self.functions_python = {f_name: f_dict["python_function"] for f_name, f_dict in functions.items()} + self.functions = agent_state.state["functions"] # these are the schema + # Link the actual python functions corresponding to the schemas + 
self.functions_python = {k: v["python_function"] for k, v in link_functions(function_schemas=self.functions).items()} + assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python # Initialize the memory object - self.memory = initialize_memory(persona_notes, human_notes) + if "persona" not in agent_state.state: + raise ValueError(f"'persona' not found in provided AgentState") + if "human" not in agent_state.state: + raise ValueError(f"'human' not found in provided AgentState") + self.memory = initialize_memory(ai_notes=agent_state.state["persona"], human_notes=agent_state.state["human"]) # Once the memory object is initialize, use it to "bake" the system message - self._messages = initialize_message_sequence( - self.model, - self.system, - self.memory, - ) + if "messages" in agent_state.state and agent_state.state["messages"] is not None: + if not isinstance(agent_state.state["messages"], list): + raise ValueError(f"'messages' in AgentState was bad type: {type(agent_state.state['messages'])}") + self._messages = agent_state.state["messages"] + else: + self._messages = initialize_message_sequence( + self.model, + self.system, + self.memory, + ) + # Interface must implement: # - internal_monologue # - assistant_message @@ -194,14 +211,9 @@ class Agent(object): # e.g., print in CLI vs send a discord message with a discord bot self.interface = interface - # Persistence manager must implement: - # - set_messages - # - get_messages - # - append_to_messages - self.persistence_manager = persistence_manager - if persistence_manager_init: - # creates a new agent object in the database - self.persistence_manager.init(self) + # Create the persistence manager object based on the AgentState info + # TODO + self.persistence_manager = LocalStateManager(agent_state=agent_state) # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) 
@@ -220,6 +232,13 @@ class Agent(object): # When the summarizer is run, set this back to False (to reset) self.agent_alerted_about_memory_pressure = False + # Initialize the connection to the DB + self.memgpt_config = MemGPTConfig() + self.ms = MetadataStore(self.memgpt_config) + + # Create the agent in the DB + self.save() + @property def messages(self): return self._messages @@ -228,14 +247,14 @@ class Agent(object): def messages(self, value): raise Exception("Modifying message list directly not allowed") - def trim_messages(self, num): + def _trim_messages(self, num): """Trim messages from the front, not including the system message""" self.persistence_manager.trim_messages(num) new_messages = [self.messages[0]] + self.messages[num:] self._messages = new_messages - def prepend_to_messages(self, added_messages): + def _prepend_to_messages(self, added_messages): """Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager""" self.persistence_manager.prepend_to_messages(added_messages) @@ -243,7 +262,7 @@ class Agent(object): self._messages = new_messages self.messages_total += len(added_messages) # still should increment the message counter (summaries are additions too) - def append_to_messages(self, added_messages): + def _append_to_messages(self, added_messages): """Wrapper around self.messages.append to allow additional calls to a state/persistence manager""" self.persistence_manager.append_to_messages(added_messages) @@ -256,148 +275,44 @@ class Agent(object): self._messages = new_messages self.messages_total += len(added_messages) - def swap_system_message(self, new_system_message): + def _swap_system_message(self, new_system_message): assert new_system_message["role"] == "system", new_system_message assert self.messages[0]["role"] == "system", self.messages - self.persistence_manager.swap_system_message(new_system_message) new_messages = [new_system_message] + self.messages[1:] # swap index 0 (system) self._messages = 
new_messages - def rebuild_memory(self): - """Rebuilds the system message with the latest memory object""" - curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt - new_system_message = initialize_message_sequence( - self.model, - self.system, - self.memory, - archival_memory=self.persistence_manager.archival_memory, - recall_memory=self.persistence_manager.recall_memory, - )[0] + def _get_ai_reply( + self, + message_sequence, + function_call="auto", + first_message=False, # hint + ): + """Get response from LLM API""" + try: + response = create( + agent_config=self.config, + messages=message_sequence, + functions=self.functions, + function_call=function_call, + # hint + first_message=first_message, + ) + # special case for 'length' + if response.choices[0].finish_reason == "length": + raise Exception("Finish reason was length (maximum context length)") - diff = united_diff(curr_system_message["content"], new_system_message["content"]) - printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") + # catches for soft errors + if response.choices[0].finish_reason not in ["stop", "function_call"]: + raise Exception(f"API call finish with bad finish reason: {response}") - # Store the memory change (if stateful) - self.persistence_manager.update_memory(self.memory) + # unpack with response.choices[0].message.content + return response + except Exception as e: + raise e - # Swap the system message out - self.swap_system_message(new_system_message) - - ### Local state management - def to_dict(self): - # TODO: select specific variables for the saves state (to eventually move to a DB) rather than checkpointing everything in the class - return { - "model": self.model, - "system": self.system, - "functions": self.functions, - "messages": self.messages, # TODO: convert to IDs - "messages_total": self.messages_total, - "memory": self.memory.to_dict(), - } - - def save_agent_state_json(self, filename): - """Save agent state to 
JSON""" - with open(filename, "w") as file: - json.dump(self.to_dict(), file) - - def save(self): - """Save agent state locally""" - - # save config - self.config.save() - - # save agent state to timestamped file - timestamp = get_local_time().replace(" ", "_").replace(":", "_") - filename = f"{timestamp}.json" - os.makedirs(self.config.save_state_dir(), exist_ok=True) - self.save_agent_state_json(os.path.join(self.config.save_state_dir(), filename)) - - # save the persistence manager too (recall/archival memory) - self.persistence_manager.save() - - @classmethod - def load_agent(cls, interface, agent_config: AgentConfig): - """Load saved agent state based on agent_config""" - # TODO: support loading from specific file - agent_name = agent_config.name - - # TODO: update this for metadata database - - # load state - directory = agent_config.save_state_dir() - json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory. - if not json_files: - print(f"/load error: no .json checkpoint files found") - raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}") - - # Sort files based on modified timestamp, with the latest file being the first. 
- filename = max(json_files, key=os.path.getmtime) - state = json.load(open(filename, "r")) - - # load persistence manager - persistence_manager = LocalStateManager.load(agent_config) - - messages = state["messages"] # TODO: reconstruct messages using recall memory + stored IDs - agent = cls( - config=agent_config, - model=state["model"], - system=state["system"], - functions=link_functions(state["functions"]), - interface=interface, - persistence_manager=persistence_manager, - persistence_manager_init=False, - persona_notes=state["memory"]["persona"], - human_notes=state["memory"]["human"], - messages_total=state["messages_total"] if "messages_total" in state else len(messages) - 1, - ) - agent._messages = messages - agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"]) - - return agent - - def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False): - """Can be used to enforce that the first message always uses send_message""" - response_message = response.choices[0].message - - # First message should be a call to send_message with a non-empty content - if require_send_message and not response_message.get("function_call"): - printd(f"First message didn't include function call: {response_message}") - return False - - function_call = response_message.get("function_call") - function_name = function_call.get("name") if function_call is not None else "" - if require_send_message and function_name != "send_message" and function_name != "archival_memory_search": - printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}") - return False - - if require_monologue and ( - not response_message.get("content") or response_message["content"] is None or response_message["content"] == "" - ): - printd(f"First message missing internal monologue: {response_message}") - return False - - if response_message.get("content"): - ### Extras - monologue = 
response_message.get("content") - - def contains_special_characters(s): - special_characters = '(){}[]"' - return any(char in s for char in special_characters) - - if contains_special_characters(monologue): - printd(f"First message internal monologue contained special characters: {response_message}") - return False - # if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower(): - if "functions" in monologue or "send_message" in monologue: - # Sometimes the syntax won't be correct and internal syntax will leak into message.context - printd(f"First message internal monologue contained reserved words: {response_message}") - return False - - return True - - def handle_ai_response(self, response_message): + def _handle_ai_response(self, response_message): """Handles parsing and function execution""" messages = [] # append these to the history when done @@ -533,11 +448,11 @@ class Agent(object): printd(f"This is the first message. Running extra verifier on AI response.") counter = 0 while True: - response = self.get_ai_reply( + response = self._get_ai_reply( message_sequence=input_message_sequence, first_message=True, # passed through to the prompt formatter ) - if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): + if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break counter += 1 @@ -545,7 +460,7 @@ class Agent(object): raise Exception(f"Hit first message retry limit ({first_message_retry_limit})") else: - response = self.get_ai_reply( + response = self._get_ai_reply( message_sequence=input_message_sequence, ) @@ -554,7 +469,7 @@ class Agent(object): # (if yes) Step 4: send the info on the function call and function response to LLM response_message = response.choices[0].message response_message_copy = response_message.copy() - all_response_messages, heartbeat_request, function_failed = self.handle_ai_response(response_message) + 
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message) # Add the extra metadata to the assistant response # (e.g. enough metadata to enable recreating the API call) @@ -577,18 +492,18 @@ class Agent(object): current_total_tokens = response["usage"]["total_tokens"] active_memory_warning = False # We can't do summarize logic properly if context_window is undefined - if self.config.context_window is None: + if self.config.llm_config.context_window is None: # Fallback if for some reason context_window is missing, just set to the default print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") print(f"{self.config}") - self.config.context_window = ( + self.config.llm_config.context_window = ( str(LLM_MAX_TOKENS[self.model]) if (self.model is not None and self.model in LLM_MAX_TOKENS) else str(LLM_MAX_TOKENS["DEFAULT"]) ) - if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window): + if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.llm_config.context_window): printd( - f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window)}" + f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.llm_config.context_window)}" ) # Only deliver the alert if we haven't already (this period) if not self.agent_alerted_about_memory_pressure: @@ -596,10 +511,10 @@ class Agent(object): self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this else: printd( - f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window)}" + f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.llm_config.context_window)}" ) - 
self.append_to_messages(all_new_messages) + self._append_to_messages(all_new_messages) return all_new_messages, heartbeat_request, function_failed, active_memory_warning except Exception as e: @@ -675,11 +590,11 @@ class Agent(object): printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}") # We can't do summarize logic properly if context_window is undefined - if self.config.context_window is None: + if self.config.llm_config.context_window is None: # Fallback if for some reason context_window is missing, just set to the default print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") print(f"{self.config}") - self.config.context_window = ( + self.config.llm_config.context_window = ( str(LLM_MAX_TOKENS[self.model]) if (self.model is not None and self.model in LLM_MAX_TOKENS) else str(LLM_MAX_TOKENS["DEFAULT"]) @@ -696,9 +611,9 @@ class Agent(object): printd(f"Packaged into message: {summary_message}") prior_len = len(self.messages) - self.trim_messages(cutoff) + self._trim_messages(cutoff) packed_summary_message = {"role": "user", "content": summary_message} - self.prepend_to_messages([packed_summary_message]) + self._prepend_to_messages([packed_summary_message]) # reset alert self.agent_alerted_about_memory_pressure = False @@ -716,31 +631,128 @@ class Agent(object): elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60 - def get_ai_reply( - self, - message_sequence, - function_call="auto", - first_message=False, # hint - ): - """Get response from LLM API""" - try: - response = create( - agent_config=self.config, - messages=message_sequence, - functions=self.functions, - function_call=function_call, - # hint - first_message=first_message, - ) - # special case for 'length' - if response.choices[0].finish_reason == "length": - raise Exception("Finish reason 
was length (maximum context length)") + def rebuild_memory(self): + """Rebuilds the system message with the latest memory object""" + curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt + new_system_message = initialize_message_sequence( + self.model, + self.system, + self.memory, + archival_memory=self.persistence_manager.archival_memory, + recall_memory=self.persistence_manager.recall_memory, + )[0] - # catches for soft errors - if response.choices[0].finish_reason not in ["stop", "function_call"]: - raise Exception(f"API call finish with bad finish reason: {response}") + diff = united_diff(curr_system_message["content"], new_system_message["content"]) + printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") - # unpack with response.choices[0].message.content - return response - except Exception as e: - raise e + # Swap the system message out + self._swap_system_message(new_system_message) + + def to_agent_state(self): + # The state may have changed since the last time we wrote it + updated_state = { + "persona": self.memory.persona, + "human": self.memory.human, + "system": self.system, + "functions": self.functions, + "messages": self.messages, + } + + agent_state = AgentState( + name=self.config.name, + user_id=self.config.user_id, + persona=self.config.persona, + human=self.config.human, + llm_config=self.config.llm_config, + embedding_config=self.config.embedding_config, + preset=self.config.preset, + id=self.config.id, + created_at=self.config.created_at, + state=updated_state, + ) + + return agent_state + + def save(self): + """Save agent state to the metadata store (creating or updating the agent record)""" + + agent_state = self.to_agent_state() + # TODO(swooders) does this make sense? 
+ # without this, even after Agent.__init__, agent.config.state["messages"] will be None + self.config = agent_state + + # Check if we need to create the agent + if not self.ms.get_agent(agent_id=agent_state.id, user_id=agent_state.user_id, agent_name=agent_state.name): + self.ms.create_agent(agent=agent_state) + else: + # Otherwise, we should update the agent + self.ms.update_agent(agent=agent_state) + + # # save config + # self.config.save() + + # # save agent state to timestamped file + # timestamp = get_local_time().replace(" ", "_").replace(":", "_") + # filename = f"{timestamp}.json" + # os.makedirs(self.config.save_state_dir(), exist_ok=True) + # self.save_agent_state_json(os.path.join(self.config.save_state_dir(), filename)) + + # # save the persistence manager too (recall/archival memory) + # self.persistence_manager.save() + + ### Local state management + # def to_dict(self): + # # TODO: select specific variables for the saves state (to eventually move to a DB) rather than checkpointing everything in the class + # return { + # "model": self.model, + # "system": self.system, + # "functions": self.functions, + # "messages": self.messages, # TODO: convert to IDs + # "messages_total": self.messages_total, + # "memory": self.memory.to_dict(), + # } + + # def save_agent_state_json(self, filename): + # """Save agent state to JSON""" + # with open(filename, "w") as file: + # json.dump(self.to_dict(), file) + + # @classmethod + # def load_agent(cls, interface, agent_config: AgentConfig): + # """Load saved agent state based on agent_config""" + # # TODO: support loading from specific file + # agent_name = agent_config.name + + # # TODO: update this for metadata database + + # # load state + # directory = agent_config.save_state_dir() + # json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory. 
+ # if not json_files: + # print(f"/load error: no .json checkpoint files found") + # raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}") + + # # Sort files based on modified timestamp, with the latest file being the first. + # filename = max(json_files, key=os.path.getmtime) + # state = json.load(open(filename, "r")) + + # # load persistence manager + # persistence_manager = LocalStateManager.load(agent_config) + + # messages = state["messages"] # TODO: reconstruct messages using recall memory + stored IDs + # agent = cls( + # config=agent_config, + # model=state["model"], + # system=state["system"], + # functions=link_functions(state["functions"]), + # interface=interface, + # persistence_manager=persistence_manager, + # persistence_manager_init=False, + # persona_notes=state["memory"]["persona"], + # human_notes=state["memory"]["human"], + # messages_total=state["messages_total"] if "messages_total" in state else len(messages) - 1, + # ) + # agent._messages = messages + # agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"]) + + # return agent diff --git a/memgpt/connectors/chroma.py b/memgpt/agent_store/chroma.py similarity index 89% rename from memgpt/connectors/chroma.py rename to memgpt/agent_store/chroma.py index b1bcb007..34316783 100644 --- a/memgpt/connectors/chroma.py +++ b/memgpt/agent_store/chroma.py @@ -3,9 +3,9 @@ import uuid import json import re from typing import Optional, List, Iterator, Dict -from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.utils import printd, datetime_to_timestamp, timestamp_to_datetime -from memgpt.config import AgentConfig, MemGPTConfig +from memgpt.config import MemGPTConfig from memgpt.data_types import Record, Message, Passage @@ -15,9 +15,8 @@ class ChromaStorageConnector(StorageConnector): # WARNING: This is not thread safe. 
Do NOT do concurrent access to the same collection. # Timestamps are converted to integer timestamps for chroma (datetime not supported) - def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): - super().__init__(table_type=table_type, agent_config=agent_config) - config = MemGPTConfig.load() + def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): + super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) assert table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES, "Chroma only supports archival memory" @@ -34,6 +33,9 @@ class ChromaStorageConnector(StorageConnector): self.collection = self.client.get_or_create_collection(self.table_name) self.include = ["documents", "embeddings", "metadatas"] + # need to be converted to strings + self.uuid_fields = ["id", "user_id", "agent_id", "source_id"] + def get_filters(self, filters: Optional[Dict] = {}): # get all filters for query if filters is not None: @@ -45,16 +47,23 @@ class ChromaStorageConnector(StorageConnector): chroma_filters = [] ids = [] for key, value in filter_conditions.items(): + # filter by id if key == "id": ids = [str(value)] continue - chroma_filters.append({key: {"$eq": value}}) + + # filter by other keys + if key in self.uuid_fields: + chroma_filters.append({key: {"$eq": str(value)}}) + else: + chroma_filters.append({key: {"$eq": value}}) if len(chroma_filters) > 1: chroma_filters = {"$and": chroma_filters} + elif len(chroma_filters) == 0: + chroma_filters = {} else: chroma_filters = chroma_filters[0] - return ids, chroma_filters def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]: @@ -79,6 +88,9 @@ class ChromaStorageConnector(StorageConnector): for metadata in results["metadatas"]: if "created_at" in metadata: metadata["created_at"] = timestamp_to_datetime(metadata["created_at"]) + for key, value in metadata.items(): + if key 
in self.uuid_fields: + metadata[key] = uuid.UUID(value) if results["embeddings"]: # may not be returned, depending on table type return [ self.type(text=text, embedding=embedding, id=uuid.UUID(record_id), **metadatas) @@ -130,6 +142,11 @@ class ChromaStorageConnector(StorageConnector): record_metadata = {} metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed metadata = {**metadata, **record_metadata} # merge with metadata + + # convert uuids to strings + for key, value in metadata.items(): + if key in self.uuid_fields: + metadata[key] = str(value) metadatas.append(metadata) return ids, documents, embeddings, metadatas diff --git a/memgpt/connectors/db.py b/memgpt/agent_store/db.py similarity index 82% rename from memgpt/connectors/db.py rename to memgpt/agent_store/db.py index 5e32f066..786f501f 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/agent_store/db.py @@ -3,12 +3,11 @@ import ast import psycopg -from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY +from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime from sqlalchemy import func from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base from sqlalchemy.orm.session import close_all_sessions from sqlalchemy.sql import func -from sqlalchemy import Column, BIGINT, String, DateTime from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy_json import mutable_json_type, MutableJson from sqlalchemy import TypeDecorator, CHAR @@ -22,11 +21,11 @@ from tqdm import tqdm import pandas as pd from memgpt.config import MemGPTConfig -from memgpt.connectors.storage import StorageConnector, TableType -from memgpt.config import AgentConfig, MemGPTConfig -from memgpt.constants import MEMGPT_DIR +from memgpt.agent_store.storage import StorageConnector, TableType +from memgpt.config import MemGPTConfig from memgpt.utils import 
printd -from memgpt.data_types import Record, Message, Passage, Source, ToolCall +from memgpt.data_types import Record, Message, Passage, ToolCall +from memgpt.metadata import MetadataStore from datetime import datetime @@ -34,6 +33,7 @@ from datetime import datetime # Custom UUID type class CommonUUID(TypeDecorator): impl = CHAR + cache_ok = True def load_dialect_impl(self, dialect): if dialect.name == "postgresql": @@ -55,27 +55,39 @@ class CommonUUID(TypeDecorator): class CommonVector(TypeDecorator): - """Common type for representing vectors in SQLite""" impl = BINARY + cache_ok = True def load_dialect_impl(self, dialect): return dialect.type_descriptor(BINARY()) def process_bind_param(self, value, dialect): - return np.array(value).tobytes() + if value: + assert isinstance(value, np.ndarray) or isinstance(value, list), f"Value must be of type np.ndarray or list, got {type(value)}" + assert isinstance(value[0], float), f"Value must be of type float, got {type(value[0])}" + # print("WRITE", np.array(value).tobytes()) + return np.array(value).tobytes() + else: + # print("WRITE", value, type(value)) + return value def process_result_value(self, value, dialect): - list_value = ast.literal_eval(value) - return np.array(list_value) + if not value: + return value + # print("dialect", dialect, type(value)) + return np.frombuffer(value) -class ToolCalls(TypeDecorator): +# Custom serialization / de-serialization for JSON columns + +class ToolCallColumn(TypeDecorator): """Custom type for storing List[ToolCall] as JSON""" impl = JSON + cache_ok = True def load_dialect_impl(self, dialect): return dialect.type_descriptor(JSON()) @@ -94,8 +106,17 @@ class ToolCalls(TypeDecorator): Base = declarative_base() -def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): - config = MemGPTConfig.load() +def get_db_model(config: MemGPTConfig, table_name: str, table_type: TableType, user_id, agent_id=None, dialect="postgresql"): + # get embedding dimension info + 
ms = MetadataStore(config) + if agent_id and ms.get_agent(agent_id): + agent = ms.get_agent(agent_id) + embedding_dim = agent.embedding_config.embedding_dim + else: + user = ms.get_user(user_id) + if user is None: + raise ValueError(f"User {user_id} not found") + embedding_dim = user.default_embedding_config.embedding_dim # Define a helper function to create or get the model class def create_or_get_model(class_name, base_model, table_name): @@ -116,10 +137,10 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - user_id = Column(String, nullable=False) + user_id = Column(CommonUUID, nullable=False) text = Column(String, nullable=False) - doc_id = Column(String) - agent_id = Column(String) + doc_id = Column(CommonUUID) + agent_id = Column(CommonUUID) data_source = Column(String) # agent_name if agent, data_source name if from data source # vector storage @@ -128,7 +149,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): else: from pgvector.sqlalchemy import Vector - embedding = mapped_column(Vector(config.embedding_dim)) + embedding = mapped_column(Vector(embedding_dim)) metadata_ = Column(MutableJson) @@ -162,20 +183,20 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - user_id = Column(String, nullable=False) - agent_id = Column(String, nullable=False) + user_id = Column(CommonUUID, nullable=False) + agent_id = Column(CommonUUID, nullable=False) # openai info role = Column(String, nullable=False) text = Column(String) # optional: can be null if function call 
model = Column(String, nullable=False) - user = Column(String) # optional: multi-agent only + name = Column(String) # optional: multi-agent only # tool call request info # if role == "assistant", this MAY be specified # if role != "assistant", this must be null # TODO align with OpenAI spec of multiple tool calls - tool_calls = Column(ToolCalls) + tool_calls = Column(ToolCallColumn) # tool call response info # if role == "tool", then this must be specified @@ -188,7 +209,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): else: from pgvector.sqlalchemy import Vector - embedding = mapped_column(Vector(config.embedding_dim)) + embedding = mapped_column(Vector(embedding_dim)) # Add a datetime column, with default value as the current time created_at = Column(DateTime(timezone=True), server_default=func.now()) @@ -201,7 +222,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): user_id=self.user_id, agent_id=self.agent_id, role=self.role, - user=self.user, + name=self.name, text=self.text, model=self.model, tool_calls=self.tool_calls, @@ -215,45 +236,22 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): class_name = f"{table_name.capitalize()}Model" + dialect return create_or_get_model(class_name, MessageModel, table_name) - elif table_type == TableType.DATA_SOURCES: - - class SourceModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" - - __abstract__ = True # this line is necessary - - # Assuming passage_id is the primary key - # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) - user_id = Column(String, nullable=False) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - - def __repr__(self): - return f"" - - def to_record(self): - return Source(id=self.id, user_id=self.user_id, 
name=self.name, created_at=self.created_at) - - """Create database model for table_name""" - class_name = f"{table_name.capitalize()}Model" + dialect - return create_or_get_model(class_name, SourceModel, table_name) - else: raise ValueError(f"Table type {table_type} not implemented") class SQLStorageConnector(StorageConnector): - def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): - super().__init__(table_type=table_type, agent_config=agent_config) - self.config = MemGPTConfig.load() + def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): + super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) + self.config = config def get_filters(self, filters: Optional[Dict] = {}): if filters is not None: filter_conditions = {**self.filters, **filters} else: filter_conditions = self.filters - return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()] + all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()] + return all_filters def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]: session = self.Session() @@ -367,10 +365,10 @@ class PostgresStorageConnector(SQLStorageConnector): # TODO: this should probably eventually be moved into a parent DB class - def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): + def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): from pgvector.sqlalchemy import Vector - super().__init__(table_type=table_type, agent_config=agent_config) + super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) # get storage URI if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: @@ -388,7 +386,7 @@ class PostgresStorageConnector(SQLStorageConnector): else: raise ValueError(f"Table type {table_type} not implemented") # 
create table - self.db_model = get_db_model(self.table_name, table_type) + self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id) self.engine = create_engine(self.uri) for c in self.db_model.__table__.columns: if c.name == "embedding": @@ -411,8 +409,8 @@ class PostgresStorageConnector(SQLStorageConnector): class SQLLiteStorageConnector(SQLStorageConnector): - def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): - super().__init__(table_type=table_type, agent_config=agent_config) + def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): + super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) # get storage URI if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: @@ -422,17 +420,13 @@ class SQLLiteStorageConnector(SQLStorageConnector): self.path = self.config.recall_storage_path if self.path is None: raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}") - elif table_type == TableType.DATA_SOURCES: - self.path = self.config.metadata_storage_path - if self.path is None: - raise ValueError(f"Must specifiy metadata_storage_path in config {self.config.metadata_storage_path}") else: raise ValueError(f"Table type {table_type} not implemented") self.path = os.path.join(self.path, f"{self.table_name}.db") # Create the SQLAlchemy engine - self.db_model = get_db_model(self.table_name, table_type, dialect="sqlite") + self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite") self.engine = create_engine(f"sqlite:///{self.path}") Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist self.Session = sessionmaker(bind=self.engine) diff --git a/memgpt/connectors/lancedb.py b/memgpt/agent_store/lancedb.py similarity index 98% rename from memgpt/connectors/lancedb.py rename to 
memgpt/agent_store/lancedb.py index 1ada2272..2c3fcadf 100644 --- a/memgpt/connectors/lancedb.py +++ b/memgpt/agent_store/lancedb.py @@ -5,7 +5,7 @@ from tqdm import tqdm from typing import Optional, List, Iterator, Dict from memgpt.config import MemGPTConfig -from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.config import AgentConfig, MemGPTConfig from memgpt.constants import MEMGPT_DIR from memgpt.utils import printd @@ -87,7 +87,7 @@ def get_db_model(table_name: str, table_type: TableType): user_id=self.user_id, agent_id=self.agent_id, role=self.role, - user=self.user, + name=self.name, text=self.text, model=self.model, function_name=self.function_name, diff --git a/memgpt/connectors/storage.py b/memgpt/agent_store/storage.py similarity index 53% rename from memgpt/connectors/storage.py rename to memgpt/agent_store/storage.py index 0e9b1ccb..87414672 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/agent_store/storage.py @@ -12,7 +12,7 @@ from typing import List, Optional, Dict from tqdm import tqdm -from memgpt.config import AgentConfig, MemGPTConfig +from memgpt.config import MemGPTConfig from memgpt.data_types import Record, Passage, Document, Message, Source from memgpt.utils import printd @@ -24,9 +24,6 @@ class TableType: RECALL_MEMORY = "recall_memory" # archival memory table: memgpt_agent_recall_{agent_id} PASSAGES = "passages" # TODO DOCUMENTS = "documents" # TODO - USERS = "users" # TODO - AGENTS = "agents" # TODO - DATA_SOURCES = "data_sources" # TODO # table names used by MemGPT @@ -36,41 +33,46 @@ RECALL_TABLE_NAME = "memgpt_recall_memory_agent" # agent memory ARCHIVAL_TABLE_NAME = "memgpt_archival_memory_agent" # agent memory # external data source tables -SOURCE_TABLE_NAME = "memgpt_sources" # metadata for loaded data source PASSAGE_TABLE_NAME = "memgpt_passages" # chunked/embedded passages (from source) DOCUMENT_TABLE_NAME = "memgpt_documents" # 
original documents (from source) class StorageConnector: - def __init__(self, table_type: TableType, agent_config: Optional[AgentConfig] = None): - config = MemGPTConfig.load() - self.agent_config = agent_config - self.user_id = config.anon_clientid + """Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory""" + + def __init__(self, table_type: TableType, config: MemGPTConfig, user_id, agent_id=None): + self.user_id = user_id + self.agent_id = agent_id self.table_type = table_type # get object type - if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: + if table_type == TableType.ARCHIVAL_MEMORY: self.type = Passage + self.table_name = ARCHIVAL_TABLE_NAME elif table_type == TableType.RECALL_MEMORY: self.type = Message - elif table_type == TableType.DATA_SOURCES: - self.type = Source + self.table_name = RECALL_TABLE_NAME + elif table_type == TableType.DOCUMENTS: + self.type = Document + self.table_name == DOCUMENT_TABLE_NAME + elif table_type == TableType.PASSAGES: + self.type = Passage + self.table_name = PASSAGE_TABLE_NAME else: raise ValueError(f"Table type {table_type} not implemented") - - # determine name of database table - self.table_name = self.generate_table_name(agent_config, table_type=table_type) printd(f"Using table name {self.table_name}") # setup base filters for agent-specific tables if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY: # agent-specific table - self.filters = {"user_id": self.user_id, "agent_id": self.agent_config.name} - elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS or self.table_type == TableType.DATA_SOURCES: + assert agent_id is not None, "Agent ID must be provided for agent-specific tables" + self.filters = {"user_id": self.user_id, "agent_id": self.agent_id} + elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS: # setup base filters for 
user-specific tables + assert agent_id is None, "Agent ID must not be provided for user-specific tables" self.filters = {"user_id": self.user_id} else: - self.filters = {} + raise ValueError(f"Table type {table_type} not implemented") def get_filters(self, filters: Optional[Dict] = {}): # get all filters for query @@ -80,78 +82,47 @@ class StorageConnector: filter_conditions = self.filters return filter_conditions - def generate_table_name(self, agent_config: AgentConfig, table_type: TableType): - if agent_config is not None: - # Table names for agent-specific tables - if table_type == TableType.ARCHIVAL_MEMORY: - return ARCHIVAL_TABLE_NAME - elif table_type == TableType.RECALL_MEMORY: - return RECALL_TABLE_NAME - else: - raise ValueError(f"Table type {table_type} not implemented") - else: - # table names for non-agent specific tables - if table_type == TableType.PASSAGES: - return PASSAGE_TABLE_NAME - elif table_type == TableType.DOCUMENTS: - return DOCUMENT_TABLE_NAME - elif table_type == TableType.DATA_SOURCES: - return SOURCE_TABLE_NAME - else: - raise ValueError(f"Table type {table_type} not implemented") - @staticmethod - def get_storage_connector(table_type: TableType, storage_type: Optional[str] = None, agent_config: Optional[AgentConfig] = None): - # read from config if not provided - if storage_type is None: - if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: - storage_type = MemGPTConfig.load().archival_storage_type - elif table_type == TableType.RECALL_MEMORY: - storage_type = MemGPTConfig.load().recall_storage_type - elif table_type == TableType.DATA_SOURCES or table_type == TableType.USERS or table_type == TableType.AGENTS: - storage_type = MemGPTConfig.load().metadata_storage_type - # TODO: other tables + def get_storage_connector(table_type: TableType, config: MemGPTConfig, user_id, agent_id=None): + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: + storage_type = 
config.archival_storage_type + elif table_type == TableType.RECALL_MEMORY: + storage_type = config.recall_storage_type + else: + raise ValueError(f"Table type {table_type} not implemented") if storage_type == "postgres": - from memgpt.connectors.db import PostgresStorageConnector + from memgpt.agent_store.db import PostgresStorageConnector - return PostgresStorageConnector(agent_config=agent_config, table_type=table_type) + return PostgresStorageConnector(table_type, config, user_id, agent_id) elif storage_type == "chroma": - from memgpt.connectors.chroma import ChromaStorageConnector + from memgpt.agent_store.chroma import ChromaStorageConnector - return ChromaStorageConnector(agent_config=agent_config, table_type=table_type) + return ChromaStorageConnector(table_type, config, user_id, agent_id) # TODO: add back # elif storage_type == "lancedb": - # from memgpt.connectors.db import LanceDBConnector + # from memgpt.agent_store.db import LanceDBConnector # return LanceDBConnector(agent_config=agent_config, table_type=table_type) - elif storage_type == "local": - from memgpt.connectors.local import InMemoryStorageConnector - - return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type) - elif storage_type == "sqlite": - from memgpt.connectors.db import SQLLiteStorageConnector + from memgpt.agent_store.db import SQLLiteStorageConnector - return SQLLiteStorageConnector(agent_config=agent_config, table_type=table_type) + return SQLLiteStorageConnector(table_type, config, user_id, agent_id) else: raise NotImplementedError(f"Storage type {storage_type} not implemented") @staticmethod - def get_archival_storage_connector(agent_config: Optional[AgentConfig] = None): - return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, agent_config=agent_config) + def get_archival_storage_connector(user_id, agent_id): + config = MemGPTConfig.load() + return StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id, agent_id) 
@staticmethod - def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None): - return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, agent_config=agent_config) - - @staticmethod - def get_metadata_storage_connector(table_type: TableType): - storage_type = MemGPTConfig.load().metadata_storage_type - return StorageConnector.get_storage_connector(table_type, storage_type=storage_type) + def get_recall_storage_connector(user_id, agent_id): + config = MemGPTConfig.load() + return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id, agent_id) @abstractmethod def get_filters(self, filters: Optional[Dict] = {}): diff --git a/memgpt/autogen/examples/agent_autoreply.py b/memgpt/autogen/examples/agent_autoreply.py index c033e554..5488bba7 100644 --- a/memgpt/autogen/examples/agent_autoreply.py +++ b/memgpt/autogen/examples/agent_autoreply.py @@ -12,8 +12,7 @@ Begin by doing: import os import autogen from memgpt.autogen.memgpt_agent import create_memgpt_autogen_agent_from_config -from memgpt.presets.presets import DEFAULT_PRESET -from memgpt.constants import LLM_MAX_TOKENS +from memgpt.constants import LLM_MAX_TOKENS, DEFAULT_PRESET LLM_BACKEND = "openai" # LLM_BACKEND = "azure" diff --git a/memgpt/autogen/examples/agent_docs.py b/memgpt/autogen/examples/agent_docs.py index 3eaa74a2..93e84261 100644 --- a/memgpt/autogen/examples/agent_docs.py +++ b/memgpt/autogen/examples/agent_docs.py @@ -15,8 +15,7 @@ Begin by doing: import os import autogen from memgpt.autogen.memgpt_agent import create_memgpt_autogen_agent_from_config -from memgpt.presets.presets import DEFAULT_PRESET -from memgpt.constants import LLM_MAX_TOKENS +from memgpt.constants import LLM_MAX_TOKENS, DEFAULT_PRESET LLM_BACKEND = "openai" # LLM_BACKEND = "azure" diff --git a/memgpt/autogen/examples/agent_groupchat.py b/memgpt/autogen/examples/agent_groupchat.py index 99203ba1..dd15b63c 100644 --- a/memgpt/autogen/examples/agent_groupchat.py +++ 
b/memgpt/autogen/examples/agent_groupchat.py @@ -13,8 +13,7 @@ Begin by doing: import os import autogen from memgpt.autogen.memgpt_agent import create_memgpt_autogen_agent_from_config -from memgpt.presets.presets import DEFAULT_PRESET -from memgpt.constants import LLM_MAX_TOKENS +from memgpt.constants import LLM_MAX_TOKENS, DEFAULT_PRESET LLM_BACKEND = "openai" # LLM_BACKEND = "azure" diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index db47ecda..838f347a 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -12,7 +12,7 @@ import memgpt.presets.presets as presets from memgpt.config import AgentConfig, MemGPTConfig from memgpt.cli.cli import attach from memgpt.cli.cli_load import load_directory, load_webpage, load_index, load_database, load_vector_database -from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.agent_store.storage import StorageConnector, TableType def create_memgpt_autogen_agent_from_config( @@ -171,7 +171,7 @@ def create_autogen_memgpt_agent( } persistence_manager = LocalStateManager(**persistence_manager_kwargs) if persistence_manager is None else persistence_manager - memgpt_agent = presets.use_preset( + memgpt_agent = presets.create_agent_from_preset( agent_config.preset, agent_config, agent_config.model, diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index c4f7344f..e076c135 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -1,4 +1,5 @@ import typer +import uuid import json import requests import sys @@ -25,6 +26,8 @@ from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX from memgpt.agent import Agent from memgpt.embeddings import embedding_model from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT +from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User +from memgpt.metadata import MetadataStore class QuickstartChoice(Enum): @@ -356,6 +359,17 @@ def run( configure() config = MemGPTConfig.load() + # read 
user id from config + ms = MetadataStore(config) + user_id = uuid.UUID(config.anon_clientid) + user = ms.get_user(user_id=user_id) + if user is None: + ms.create_user(User(id=user_id)) + user = ms.get_user(user_id=user_id) + if user is None: + typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED) + sys.exit(1) + # override with command line arguments if debug: config.debug = debug @@ -364,8 +378,8 @@ def run( # determine agent to use, if not provided if not yes and not agent: - agent_files = utils.list_agent_config_files() - agents = [AgentConfig.load(f).name for f in agent_files] + agents = ms.list_agents(user_id=user.id) + agents = [a.name for a in agents] if len(agents) > 0 and not any([persona, human, model]): print() @@ -373,154 +387,166 @@ def run( if select_agent: agent = questionary.select("Select agent:", choices=agents).ask() - # configure llama index - config = MemGPTConfig.load() - original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index - sys.stdout = io.StringIO() - embed_model = embedding_model() - service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size) - set_global_service_context(service_context) - sys.stdout = original_stdout - # create agent config - if agent and AgentConfig.exists(agent): # use existing agent + if agent and ms.get_agent(agent_name=agent, user_id=user.id): # use existing agent typer.secho(f"\nšŸ” Using existing agent {agent}", fg=typer.colors.GREEN) - agent_config = AgentConfig.load(agent) - printd("State path:", agent_config.save_state_dir()) - printd("Persistent manager path:", agent_config.save_persistence_manager_dir()) - printd("Index path:", agent_config.save_agent_index_dir()) + # agent_config = AgentConfig.load(agent) + agent_state = ms.get_agent(agent_name=agent, user_id=user_id) + printd("Loading agent state:", agent_state.id) + printd("Agent state:", agent_state.state) + # 
printd("State path:", agent_config.save_state_dir()) + # printd("Persistent manager path:", agent_config.save_persistence_manager_dir()) + # printd("Index path:", agent_config.save_agent_index_dir()) # persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load # TODO: load prior agent state - if persona and persona != agent_config.persona: - typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing persona {agent_config.persona} with {persona}", fg=typer.colors.YELLOW) - agent_config.persona = persona - # raise ValueError(f"Cannot override {agent_config.name} existing persona {agent_config.persona} with {persona}") - if human and human != agent_config.human: - typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW) - agent_config.human = human + if persona and persona != agent_state.persona: + typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing persona {agent_state.persona} with {persona}", fg=typer.colors.YELLOW) + agent_state.persona = persona + # raise ValueError(f"Cannot override {agent_state.name} existing persona {agent_state.persona} with {persona}") + if human and human != agent_state.human: + typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing human {agent_state.human} with {human}", fg=typer.colors.YELLOW) + agent_state.human = human # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}") # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window) - if model and model != agent_config.model: - typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW) - agent_config.model = model - if context_window is not None and int(context_window) != agent_config.context_window: + if model and model != agent_state.llm_config.model: typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing context window {agent_config.context_window} 
with {context_window}", + f"{CLI_WARNING_PREFIX}Overriding existing model {agent_state.llm_config.model} with {model}", fg=typer.colors.YELLOW + ) + agent_state.llm_config.model = model + if context_window is not None and int(context_window) != agent_state.llm_config.context_window: + typer.secho( + f"{CLI_WARNING_PREFIX}Overriding existing context window {agent_state.llm_config.context_window} with {context_window}", fg=typer.colors.YELLOW, ) - agent_config.context_window = context_window - if model_wrapper and model_wrapper != agent_config.model_wrapper: + agent_state.llm_config.context_window = context_window + if model_wrapper and model_wrapper != agent_state.llm_config.model_wrapper: typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", + f"{CLI_WARNING_PREFIX}Overriding existing model wrapper {agent_state.llm_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW, ) - agent_config.model_wrapper = model_wrapper - if model_endpoint and model_endpoint != agent_config.model_endpoint: + agent_state.llm_config.model_wrapper = model_wrapper + if model_endpoint and model_endpoint != agent_state.llm_config.model_endpoint: typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", + f"{CLI_WARNING_PREFIX}Overriding existing model endpoint {agent_state.llm_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW, ) - agent_config.model_endpoint = model_endpoint - if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type: + agent_state.llm_config.model_endpoint = model_endpoint + if model_endpoint_type and model_endpoint_type != agent_state.llm_config.model_endpoint_type: typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}", + f"{CLI_WARNING_PREFIX}Overriding existing model endpoint type 
{agent_state.llm_config.model_endpoint_type} with {model_endpoint_type}", fg=typer.colors.YELLOW, ) - agent_config.model_endpoint_type = model_endpoint_type + agent_state.llm_config.model_endpoint_type = model_endpoint_type - # Update the agent config with any overrides - agent_config.save() + # Update the agent with any overrides + ms.update_agent(agent_state) - # Supress llama-index noise - with suppress_stdout(): - # load existing agent - memgpt_agent = Agent.load_agent(interface, agent_config) + # create agent + memgpt_agent = Agent(agent_state, interface=interface) else: # create new agent # create new agent config: override defaults with args if provided typer.secho("\n🧬 Creating new agent...", fg=typer.colors.WHITE) - agent_config = AgentConfig( - name=agent, - persona=persona, - human=human, - preset=preset, - model=model, - model_wrapper=model_wrapper, - model_endpoint_type=model_endpoint_type, - model_endpoint=model_endpoint, - context_window=context_window, - ) - # save new agent config - agent_config.save() - typer.secho(f"-> šŸ¤– Using persona profile '{agent_config.persona}'", fg=typer.colors.WHITE) - typer.secho(f"-> šŸ§‘ Using human profile '{agent_config.human}'", fg=typer.colors.WHITE) + if agent is None: + # determine agent name + # agent_count = len(ms.list_agents(user_id=user.id)) + # agent = f"agent_{agent_count}" + agent = utils.create_random_username() + + agent_state = AgentState( + name=agent, + user_id=user.id, + persona=persona if persona else user.default_persona, + human=human if human else user.default_human, + preset=preset if preset else user.default_preset, + llm_config=user.default_llm_config, + embedding_config=user.default_embedding_config, + ) + ms.create_agent(agent_state) + + typer.secho(f"-> šŸ¤– Using persona profile '{agent_state.persona}'", fg=typer.colors.WHITE) + typer.secho(f"-> šŸ§‘ Using human profile '{agent_state.human}'", fg=typer.colors.WHITE) # Supress llama-index noise - with suppress_stdout(): - # TODO: allow 
configrable state manager (only local is supported right now) - persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill + # TODO(swooders) add persistence manager code? or comment out? + # with suppress_stdout(): + # TODO: allow configrable state manager (only local is supported right now) + # persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill # create agent try: - memgpt_agent = presets.use_preset( - agent_config.preset, - agent_config, - agent_config.model, - utils.get_persona_text(agent_config.persona), - utils.get_human_text(agent_config.human), - interface, - persistence_manager, + memgpt_agent = presets.create_agent_from_preset( + agent_state=agent_state, + interface=interface, ) except ValueError as e: + # TODO(swooders) what's the equivalent cleanup code for the new DB refactor? typer.secho(f"Failed to create agent from provided information:\n{e}", fg=typer.colors.RED) - # Delete the directory of the failed agent - try: - # Path to the specific file - agent_config_file = agent_config.agent_config_path + # # Delete the directory of the failed agent + # try: + # # Path to the specific file + # agent_config_file = agent_config.agent_config_path - # Check if the file exists - if os.path.isfile(agent_config_file): - # Delete the file - os.remove(agent_config_file) + # # Check if the file exists + # if os.path.isfile(agent_config_file): + # # Delete the file + # os.remove(agent_config_file) - # Now, delete the directory along with any remaining files in it - agent_save_dir = os.path.join(MEMGPT_DIR, "agents", agent_config.name) - shutil.rmtree(agent_save_dir) - except: - typer.secho(f"Failed to delete agent directory during cleanup:\n{e}", fg=typer.colors.RED) + # # Now, delete the directory along with any remaining files in it + # agent_save_dir = os.path.join(MEMGPT_DIR, "agents", agent_config.name) + # shutil.rmtree(agent_save_dir) + # except: + # typer.secho(f"Failed to delete agent directory 
during cleanup:\n{e}", fg=typer.colors.RED) sys.exit(1) - typer.secho(f"šŸŽ‰ Created new agent '{agent_config.name}'", fg=typer.colors.GREEN) + typer.secho(f"šŸŽ‰ Created new agent '{agent_state.name}'", fg=typer.colors.GREEN) # pretty print agent config - printd(json.dumps(vars(agent_config), indent=4, sort_keys=True)) + # printd(json.dumps(vars(agent_config), indent=4, sort_keys=True)) + # printd(json.dumps(agent_init_state), indent=4, sort_keys=True)) + + # configure llama index + original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index + sys.stdout = io.StringIO() + embed_model = embedding_model(config=agent_state.embedding_config, user_id=user.id) + service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size) + set_global_service_context(service_context) + sys.stdout = original_stdout # start event loop from memgpt.main import run_agent_loop print() # extra space - run_agent_loop(memgpt_agent, first, no_verify, config) # TODO: add back no_verify + run_agent_loop(memgpt_agent, config, first, no_verify) # TODO: add back no_verify def attach( agent: str = typer.Option(help="Specify agent to attach data to"), data_source: str = typer.Option(help="Data source to attach to avent"), + user_id: uuid.UUID = None, ): + # use client ID is no user_id provided + config = MemGPTConfig.load() + if user_id is None: + user_id = uuid.UUID(config.anon_clientid) try: # loads the data contained in data source into the agent's memory - from memgpt.connectors.storage import StorageConnector, TableType + from memgpt.agent_store.storage import StorageConnector, TableType from tqdm import tqdm - agent_config = AgentConfig.load(agent) + ms = MetadataStore(config) + agent = ms.get_agent(agent_name=agent, user_id=user_id) + source = ms.get_source(source_name=data_source, user_id=user_id) + assert source is not None, f"Source {data_source} does not exist for user {user_id}" 
# get storage connectors with suppress_stdout(): - source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES) - dest_storage = StorageConnector.get_storage_connector(table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config) + source_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id) + dest_storage = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id) size = source_storage.size({"data_source": data_source}) - typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN) + typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN) page_size = 100 generator = source_storage.get_all_paginated(filters={"data_source": data_source}, page_size=page_size) # yields List[Passage] passages = [] @@ -530,7 +556,7 @@ def attach( # need to associated passage with agent (for filtering) for passage in passages: - passage.agent_id = agent_config.name + passage.agent_id = agent.id # insert into agent archival memory dest_storage.insert_many(passages) @@ -538,6 +564,10 @@ def attach( # save destination storage dest_storage.save() + # attach to agent + source_id = ms.get_source(source_name=data_source, user_id=user_id).id + ms.attach_source(agent_id=agent.id, source_id=source_id, user_id=user_id) + total_agent_passages = dest_storage.size() typer.secho( diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 0b480712..728f7f3a 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -1,4 +1,5 @@ import builtins +import uuid import questionary from prettytable import PrettyTable import typer @@ -10,14 +11,17 @@ from enum import Enum # from memgpt.cli import app from memgpt import utils -from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.config import MemGPTConfig from memgpt.constants import MEMGPT_DIR -from memgpt.connectors.storage import StorageConnector, 
TableType + +# from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.constants import LLM_MAX_TOKENS from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME from memgpt.local_llm.utils import get_available_wrappers from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin from memgpt.server.utils import shorten_key_middle +from memgpt.data_types import User, LLMConfig, EmbeddingConfig +from memgpt.metadata import MetadataStore app = typer.Typer() @@ -463,6 +467,7 @@ def configure(): typer.secho(str(e), fg=typer.colors.RED) return + # TODO: remove most of this (deplicated with User table) config = MemGPTConfig( # model configs model=model, @@ -500,9 +505,44 @@ def configure(): metadata_storage_uri=recall_storage_uri, metadata_storage_path=recall_storage_path, ) + typer.secho(f"šŸ“– Saving config to {config.config_path}", fg=typer.colors.GREEN) config.save() + # create user records + ms = MetadataStore(config) + user_id = uuid.UUID(config.anon_clientid) + user = User( + id=uuid.UUID(config.anon_clientid), + default_preset=default_preset, + default_persona=default_persona, + default_human=default_human, + default_agent=default_agent, + default_llm_config=LLMConfig( + model=model, + model_endpoint=model_endpoint, + model_endpoint_type=model_endpoint_type, + model_wrapper=model_wrapper, + context_window=context_window, + ), + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type=embedding_endpoint_type, + embedding_endpoint=embedding_endpoint, + embedding_dim=embedding_dim, + embedding_model=embedding_model, + openai_key=openai_key, + azure_key=azure_creds["azure_key"], + azure_endpoint=azure_creds["azure_endpoint"], + azure_version=azure_creds["azure_version"], + azure_deployment=azure_creds["azure_deployment"], # OK if None + ), + ) + if ms.get_user(user_id): + # update user + ms.update_user(user) + else: + ms.create_user(user) + class 
ListChoice(str, Enum): agents = "agents" @@ -513,21 +553,24 @@ class ListChoice(str, Enum): @app.command() def list(arg: Annotated[ListChoice, typer.Argument]): + config = MemGPTConfig.load() + ms = MetadataStore(config) + user_id = uuid.UUID(config.anon_clientid) if arg == ListChoice.agents: """List all agents""" table = PrettyTable() table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"] - for agent_file in utils.list_agent_config_files(): - agent_name = os.path.basename(agent_file).replace(".json", "") - agent_config = AgentConfig.load(agent_name) + for agent in ms.list_agents(user_id=user_id): + source_ids = ms.list_attached_sources(agent_id=agent.id) + source_names = [ms.get_source(source_id=source_id).name for source_id in source_ids] table.add_row( [ - agent_name, - agent_config.model, - agent_config.persona, - agent_config.human, - ",".join(agent_config.data_sources), - agent_config.create_time, + agent.name, + agent.llm_config.model, + agent.persona, + agent.human, + ",".join(source_names), + utils.format_datetime(agent.created_at), ] ) print(table) @@ -552,20 +595,22 @@ def list(arg: Annotated[ListChoice, typer.Argument]): print(table) elif arg == ListChoice.sources: """List all data sources""" - conn = StorageConnector.get_metadata_storage_connector(table_type=TableType.DATA_SOURCES) # already filters by user - passage_conn = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES) # create table table = PrettyTable() - table.field_names = ["Name", "Created At", "Number of Passages", "Agents"] + table.field_names = ["Name", "Created At", "Agents"] # TODO: eventually look accross all storage connections # TODO: add data source stats # TODO: connect to agents # get all sources - for data_source in conn.get_all(): - num_passages = passage_conn.size({"data_source": data_source.name}) - table.add_row([data_source.name, data_source.created_at, num_passages, ""]) + for source in ms.list_sources(user_id=user_id): + # 
get attached agents + agent_ids = ms.list_attached_agents(source_id=source.id) + agent_names = [ms.get_agent(agent_id=agent_id).name for agent_id in agent_ids] + + table.add_row([source.name, utils.format_datetime(source.created_at), ",".join(agent_names)]) + print(table) else: raise ValueError(f"Unknown argument {arg}") diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index ebb5e9d6..743e89c4 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -11,12 +11,14 @@ memgpt load --name [ADDITIONAL ARGS] from typing import List from tqdm import tqdm import typer +import uuid from memgpt.embeddings import embedding_model -from memgpt.connectors.storage import StorageConnector +from memgpt.agent_store.storage import StorageConnector from memgpt.config import MemGPTConfig -from memgpt.data_types import Source, Passage, Document +from memgpt.metadata import MetadataStore +from memgpt.data_types import Source, Passage, Document, User from memgpt.utils import get_local_time, suppress_stdout -from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.agent_store.storage import StorageConnector, TableType from datetime import datetime @@ -30,23 +32,30 @@ from llama_index import ( app = typer.Typer() -def store_docs(name, docs, show_progress=True): +def store_docs(name, docs, user_id=None, show_progress=True): """Common function for embedding and storing documents""" config = MemGPTConfig.load() + if user_id is None: # assume running local with single user + user_id = uuid.UUID(config.anon_clientid) # record data source metadata - data_source = Source(user_id=config.anon_clientid, name=name, created_at=datetime.now()) - metadata_conn = StorageConnector.get_metadata_storage_connector(TableType.DATA_SOURCES) - if len(metadata_conn.get_all({"name": name})) > 0: - print(f"Data source {name} already exists in metadata, skipping.") - # TODO: should this error, or just add more data to this source? 
+ ms = MetadataStore(config) + user = ms.get_user(user_id) + print("USER", user) + data_source = Source(user_id=user.id, name=name, created_at=datetime.now()) + if not ms.get_source(user_id=user.id, source_name=name): + print("Trying to add...") + ms.create_source(data_source) + print("Created source", data_source) else: - metadata_conn.insert(data_source) + print(f"Source {name} for user {user.id} already exists") # compute and record passages - storage = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=config.archival_storage_type) - embed_model = embedding_model() + print("USER ID", user.id) + storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user.id) + print("embedding config", user.default_embedding_config, user.default_embedding_config.embedding_dim) + embed_model = embedding_model(user.default_embedding_config) orig_size = storage.size() # use llama index to run embeddings code @@ -65,11 +74,11 @@ def store_docs(name, docs, show_progress=True): node.embedding = vector text = node.text.replace("\x00", "\uFFFD") # hacky fix for error on null characters assert ( - len(node.embedding) == config.embedding_dim - ), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}: {node.embedding}" + len(node.embedding) == user.default_embedding_config.embedding_dim + ), f"Expected embedding dimension {user.default_embedding_config.embedding_dim}, got {len(node.embedding)}: {node.embedding}" passages.append( Passage( - user_id=config.anon_clientid, + user_id=user.id, text=text, data_source=name, embedding=node.embedding, @@ -119,6 +128,7 @@ def load_directory( input_dir: str = typer.Option(None, help="Path to directory containing dataset."), input_files: List[str] = typer.Option(None, help="List of paths to files containing dataset."), recursive: bool = typer.Option(False, help="Recursively search for files in directory."), + user_id: str = typer.Option(None, help="User ID to associate with 
dataset."), ): try: from llama_index import SimpleDirectoryReader @@ -136,7 +146,7 @@ def load_directory( # load docs docs = reader.load_data() - store_docs(name, docs) + store_docs(name, docs, user_id) except ValueError as e: typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 0bc72451..ce4f2a1b 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -1,6 +1,7 @@ import os from typing import Dict, List, Union +from memgpt.data_types import AgentState from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig, AgentConfig @@ -12,6 +13,7 @@ from memgpt.server.server import SyncServer class Client(object): def __init__( self, + user_id: str = None, auto_save: bool = False, quickstart: Union[QuickstartChoice, str, None] = None, config: Union[Dict, MemGPTConfig] = None, # not the same thing as AgentConfig @@ -24,7 +26,6 @@ class Client(object): :param config: optional config settings to apply after quickstart :param debug: indicates whether to display debug messages. 
""" - self.user_id = "null" self.auto_save = auto_save # make sure everything is set up properly @@ -56,6 +57,7 @@ class Client(object): if config is not None: set_config_with_dict(config) + self.user_id = MemGPTConfig.load().anon_clientid if user_id is None else user_id self.interface = QueuingInterface(debug=debug) self.server = SyncServer(default_interface=self.interface) @@ -69,24 +71,22 @@ class Client(object): def create_agent( self, - agent_config: Union[Dict, AgentConfig], - persistence_manager: Union[PersistenceManager, None] = None, - throw_if_exists: bool = False, - ) -> str: + agent_config: dict, + # persistence_manager: Union[PersistenceManager, None] = None, + throw_if_exists: bool = True, + ) -> AgentState: if isinstance(agent_config, dict): agent_name = agent_config.get("name") else: - agent_name = agent_config.name + raise TypeError(f"agent_config must be of type dict") if not self.agent_exists(agent_id=agent_name): self.interface.clear() - return self.server.create_agent(user_id=self.user_id, agent_config=agent_config, persistence_manager=persistence_manager) - - if throw_if_exists: + agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config) + return agent_state + else: raise ValueError(f"Agent {agent_name} already exists") - return agent_name - def get_agent_config(self, agent_id: str) -> Dict: self.interface.clear() return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id) diff --git a/memgpt/config.py b/memgpt/config.py index 65da99c5..a39c0489 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -9,8 +9,8 @@ import memgpt import memgpt.utils as utils from memgpt.utils import printd, get_schema_diff from memgpt.functions.functions import load_all_function_sets -from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS, DEFAULT_HUMAN, DEFAULT_PERSONA -from memgpt.presets.presets import DEFAULT_PRESET +from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS, DEFAULT_HUMAN, DEFAULT_PERSONA, 
DEFAULT_PRESET +from memgpt.data_types import AgentState, User, LLMConfig, EmbeddingConfig # helper functions for writing to configs @@ -31,6 +31,119 @@ def set_field(config, section, field, value): config.set(section, field, value) +@dataclass +class Config: + # system config for MemGPT + config_path = os.path.join(MEMGPT_DIR, "config") + anon_clientid = None + + # database configs: archival + archival_storage_type: str = "chroma" # local, db + archival_storage_path: str = os.path.join(MEMGPT_DIR, "chroma") + archival_storage_uri: str = None # TODO: eventually allow external vector DB + + # database configs: recall + recall_storage_type: str = "sqlite" # local, db + recall_storage_path: str = MEMGPT_DIR + recall_storage_uri: str = None # TODO: eventually allow external vector DB + + # database configs: metadata storage (sources, agents, data sources) + metadata_storage_type: str = "sqlite" + metadata_storage_path: str = MEMGPT_DIR + metadata_storage_uri: str = None + + memgpt_version: str = None + + @classmethod + def load(cls) -> "MemGPTConfig": + config = configparser.ConfigParser() + + # allow overriding with env variables + if os.getenv("MEMGPT_CONFIG_PATH"): + config_path = os.getenv("MEMGPT_CONFIG_PATH") + else: + config_path = MemGPTConfig.config_path + + if os.path.exists(config_path): + # read existing config + config.read(config_path) + config_dict = { + "archival_storage_type": get_field(config, "archival_storage", "type"), + "archival_storage_path": get_field(config, "archival_storage", "path"), + "archival_storage_uri": get_field(config, "archival_storage", "uri"), + "recall_storage_type": get_field(config, "recall_storage", "type"), + "recall_storage_path": get_field(config, "recall_storage", "path"), + "recall_storage_uri": get_field(config, "recall_storage", "uri"), + "metadata_storage_type": get_field(config, "metadata_storage", "type"), + "metadata_storage_path": get_field(config, "metadata_storage", "path"), + "metadata_storage_uri": 
get_field(config, "metadata_storage", "uri"), + "anon_clientid": get_field(config, "client", "anon_clientid"), + "config_path": config_path, + "memgpt_version": get_field(config, "version", "memgpt_version"), + } + config_dict = {k: v for k, v in config_dict.items() if v is not None} + return cls(**config_dict) + + # create new config + anon_clientid = str(uuid.uuid()) + config = cls(anon_clientid=anon_clientid, config_path=config_path) + config.save() # save updated config + return config + + def save(self): + import memgpt + + config = configparser.ConfigParser() + # archival storage + set_field(config, "archival_storage", "type", self.archival_storage_type) + set_field(config, "archival_storage", "path", self.archival_storage_path) + set_field(config, "archival_storage", "uri", self.archival_storage_uri) + + # recall storage + set_field(config, "recall_storage", "type", self.recall_storage_type) + set_field(config, "recall_storage", "path", self.recall_storage_path) + set_field(config, "recall_storage", "uri", self.recall_storage_uri) + + # metadata storage + set_field(config, "metadata_storage", "type", self.metadata_storage_type) + set_field(config, "metadata_storage", "path", self.metadata_storage_path) + set_field(config, "metadata_storage", "uri", self.metadata_storage_uri) + + # set version + set_field(config, "version", "memgpt_version", memgpt.__version__) + + # client + if not self.anon_clientid: + self.anon_clientid = str(uuid.uuid()) + set_field(config, "client", "anon_clientid", self.anon_clientid) + + if not os.path.exists(MEMGPT_DIR): + os.makedirs(MEMGPT_DIR, exist_ok=True) + with open(self.config_path, "w") as f: + config.write(f) + + @staticmethod + def exists(): + # allow overriding with env variables + if os.getenv("MEMGPT_CONFIG_PATH"): + config_path = os.getenv("MEMGPT_CONFIG_PATH") + else: + config_path = MemGPTConfig.config_path + + assert not os.path.isdir(config_path), f"Config path {config_path} cannot be set to a directory." 
+ return os.path.exists(config_path) + + @staticmethod + def create_config_dir(): + if not os.path.exists(MEMGPT_DIR): + os.makedirs(MEMGPT_DIR, exist_ok=True) + + folders = ["functions", "system_prompts", "presets", "settings"] + for folder in folders: + if not os.path.exists(os.path.join(MEMGPT_DIR, folder)): + os.makedirs(os.path.join(MEMGPT_DIR, folder)) + + @dataclass class MemGPTConfig: config_path: str = os.path.join(MEMGPT_DIR, "config") @@ -354,6 +467,17 @@ class AgentConfig: with open(self.agent_config_path, "w") as f: json.dump(vars(self), f, indent=4) + def to_agent_state(self): + return AgentState( + name=self.name, + preset=self.preset, + persona=self.persona, + human=self.human, + llm_config=self.llm_config, + embedding_config=self.embedding_config, + create_time=self.create_time, + ) + @staticmethod def exists(name: str): """Check if agent config exists""" diff --git a/memgpt/connectors/local.py b/memgpt/connectors/local.py deleted file mode 100644 index dc554952..00000000 --- a/memgpt/connectors/local.py +++ /dev/null @@ -1,314 +0,0 @@ -from typing import Optional, List, Iterator -import shutil -from memgpt.config import AgentConfig, MemGPTConfig -from tqdm import tqdm -import re -import pickle -import os - -import json -import glob -from typing import List, Optional, Dict -from abc import abstractmethod - -from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.empty.base import EmptyIndex -from llama_index.retrievers import VectorIndexRetriever -from llama_index.schema import TextNode - -from memgpt.constants import MEMGPT_DIR -from memgpt.data_types import Record -from memgpt.config import MemGPTConfig -from memgpt.connectors.storage import StorageConnector, TableType -from memgpt.config import AgentConfig, MemGPTConfig -from memgpt.utils import printd, get_local_time, parse_formatted_time -from memgpt.data_types import Message, 
Passage, Record - -# class VectorIndexStorageConnector(StorageConnector): - -# """Local storage connector based on LlamaIndex""" - -# def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): -# super().__init__(table_type=table_type, agent_config=agent_config) -# config = MemGPTConfig.load() - -## TODO: add asserts to avoid both being passed -# if agent_config is not None: -# self.name = agent_config.name -# self.save_directory = agent_config.save_agent_index_dir() -# else: -# self.name = name -# self.save_directory = f"{MEMGPT_DIR}/archival/{name}" - -## llama index contexts -# self.embed_model = embedding_model() -# self.service_context = ServiceContext.from_defaults(llm=None, embed_model=self.embed_model, chunk_size=config.embedding_chunk_size) -# set_global_service_context(self.service_context) - -## load/create index -# self.save_path = f"{self.save_directory}/nodes.pkl" -# if os.path.exists(self.save_path): -# self.nodes = pickle.load(open(self.save_path, "rb")) -# else: -# self.nodes = [] - -## create vectorindex -# if len(self.nodes): -# self.index = VectorStoreIndex(self.nodes) -# else: -# self.index = EmptyIndex() - -# def load(self, filters: Dict): -## load correct version based off filters -# if "agent_id" in filters and filters["agent_id"] is not None: -## load agent archival memory -# save_directory = self.agent_config.save_agent_index_dir() -# elif "data_source" in filters and filters["data_source"] is not None: -# name = filters["data_source"] -# save_directory = f"{MEMGPT_DIR}/archival/{name}" -# else: -# raise ValueError(f"Cannot load index without agent_id or data_source {filters}") -# save_path = f"{save_directory}/nodes.pkl" -# if os.path.exists(save_path): -# nodes = pickle.load(open(save_path, "rb")) -# else: -# nodes = [] -## create vectorindex -# if len(self.nodes): -# self.index = VectorStoreIndex(self.nodes) -# else: -# self.index = EmptyIndex() - - -# def get_nodes(self) -> List[TextNode]: -# """Get llama index 
nodes""" -# embed_dict = self.index._vector_store._data.embedding_dict -# node_dict = self.index._docstore.docs - -# nodes = [] -# for node_id, node in node_dict.items(): -# vector = embed_dict[node_id] -# node.embedding = vector -# nodes.append(TextNode(text=node.text, embedding=vector)) -# return nodes - -# def add_nodes(self, nodes: List[TextNode]): -# self.nodes += nodes -# self.index = VectorStoreIndex(self.nodes) - -# def get_all_paginated(self, page_size: int = 100) -> Iterator[List[Passage]]: -# """Get all passages in the index""" -# nodes = self.get_nodes() -# for i in tqdm(range(0, len(nodes), page_size)): -# yield [Passage(text=node.text, embedding=node.embedding) for node in nodes[i : i + page_size]] - -# def get_all(self, limit: int) -> List[Passage]: -# passages = [] -# for node in self.get_nodes(): -# assert node.embedding is not None, f"Node embedding is None" -# passages.append(Passage(text=node.text, embedding=node.embedding)) -# if len(passages) >= limit: -# break -# return passages - -# def get(self, id: str) -> Passage: -# pass - -# def insert(self, passage: Passage): -# nodes = [TextNode(text=passage.text, embedding=passage.embedding)] -# self.nodes += nodes -# if isinstance(self.index, EmptyIndex): -# self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True) -# else: -# self.index.insert_nodes(nodes) - -# def insert_many(self, passages: List[Passage]): -# nodes = [TextNode(text=passage.text, embedding=passage.embedding) for passage in passages] -# self.nodes += nodes -# if isinstance(self.index, EmptyIndex): -# self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True) -# else: -# orig_size = len(self.get_nodes()) -# self.index.insert_nodes(nodes) -# assert len(self.get_nodes()) == orig_size + len( -# passages -# ), f"expected {orig_size + len(passages)} nodes, got {len(self.get_nodes())} nodes" - -# def query(self, query: str, query_vec: List[float], top_k: 
int = 10) -> List[Passage]: -# if isinstance(self.index, EmptyIndex): # empty index -# return [] -## TODO: this may be super slow? -## the nice thing about creating this here is that now we can save the persistent storage manager -# retriever = VectorIndexRetriever( -# index=self.index, # does this get refreshed? -# similarity_top_k=top_k, -# ) -# nodes = retriever.retrieve(query) -# results = [Passage(embedding=node.embedding, text=node.text) for node in nodes] -# return results - -# def save(self): -## assert len(self.nodes) == len(self.get_nodes()), f"Expected {len(self.nodes)} nodes, got {len(self.get_nodes())} nodes" -# self.nodes = self.get_nodes() -# os.makedirs(self.save_directory, exist_ok=True) -# pickle.dump(self.nodes, open(self.save_path, "wb")) - -# @staticmethod -# def list_loaded_data(): -# sources = [] -# for data_source_file in os.listdir(os.path.join(MEMGPT_DIR, "archival")): -# name = os.path.basename(data_source_file) -# sources.append(name) -# return sources - -# def size(self): -# return len(self.get_nodes()) - - -class InMemoryStorageConnector(StorageConnector): - """Really dumb class so we can have a unified storae connector interface - keeps everything in memory""" - - """ Backwards compatible with previous version of recall memory """ - - # TODO: maybae replace this with sqllite? 
- - def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): - super().__init__(table_type=table_type, agent_config=agent_config) - config = MemGPTConfig.load() - - # supported table types - self.supported_types = [TableType.RECALL_MEMORY] - if table_type not in self.supported_types: - raise ValueError(f"Table type {table_type} not supported by InMemoryStorageConnector") - - # TODO: load if exists - self.agent_config = agent_config - if agent_config is None: - # is a data source - raise ValueError("Cannot load data source from InMemoryStorageConnector") - else: - directory = agent_config.save_state_dir() - if os.path.exists(directory): - print(f"Loading saved agent {agent_config.name} from {directory}") - json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory. - if not json_files: - print(f"/load error: no .json checkpoint files found") - raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}") - - # Sort files based on modified timestamp, with the latest file being the first. 
- filename = max(json_files, key=os.path.getmtime) - state = json.load(open(filename, "r")) - - # load persistence manager - filename = os.path.basename(filename).replace(".json", ".persistence.pickle") - directory = agent_config.save_persistence_manager_dir() - printd(f"Loading persistence manager from {os.path.join(directory, filename)}") - with open(filename, "rb") as f: - data = pickle.load(f) - self.rows = data["all_messages"] - else: - print(f"Creating new agent {agent_config.name}") - self.rows = [] - - # convert to Record class - self.rows = [self.json_to_message(m) for m in self.rows] - - def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]: - offset = 0 - while True: - yield self.rows[offset : offset + page_size] - offset += page_size - if offset >= len(self.rows): - break - - def get_all(self, limit: Optional[int] = None, filters: Optional[Dict] = {}) -> List[Record]: - if limit: - return self.rows[:limit] - return self.rows - - def get(self, id: str) -> Record: - match_row = [row for row in self.rows if row.id == id] - if len(match_row) == 0: - return None - assert len(match_row) == 1, f"Expected 1 match, got {len(match_row)} matches" - return match_row[0] - - def insert(self, record: Record): - self.rows.append(record) - - def insert_many(self, records: List[Record]): - self.rows += records - - def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: - raise NotImplementedError - - def json_to_message(self, message_json) -> Message: - """Convert agent message JSON into Message object""" - timestamp = message_json["timestamp"] - message = message_json["message"] - - return Message( - user_id=self.config.anon_clientid, - agent_id=self.agent_config.name, - role=message["role"], - text=message["content"], - model=self.agent_config.model, - created_at=parse_formatted_time(timestamp), - function_name=message["function_name"] if "function_name" in message 
else None, - function_args=message["function_args"] if "function_args" in message else None, - function_response=message["function_response"] if "function_response" in message else None, - id=message["id"] if "id" in message else None, - ) - - def message_to_json(self, message: Message) -> Dict: - """Convert Message object into JSON""" - return { - "timestamp": message.created_at.strftime("%Y-%m-%d %H:%M:%S %Z%z"), - "message": { - "role": message.role, - "content": message.text, - "function_name": message.function_name, - "function_args": message.function_args, - "function_response": message.function_response, - "id": message.id, - }, - } - - def save(self): - """Save state of storage connector""" - timestamp = get_local_time().replace(" ", "_").replace(":", "_") - filename = f"{timestamp}.persistence.pickle" - os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True) - filename = os.path.join(self.config.save_persistence_manager_dir(), filename) - - all_messages = [self.message_to_json(m) for m in self.rows] - - with open(filename, "wb") as fh: - ## TODO: fix this hacky solution to pickle the retriever - pickle.dump( - { - "all_messages": all_messages, - }, - fh, - protocol=pickle.HIGHEST_PROTOCOL, - ) - printd(f"Saved state to {fh}") - - def size(self, filters: Optional[Dict] = {}) -> int: - return len(self.rows) - - def query_date(self, start_date, end_date) -> List[Record]: - return [row for row in self.rows if row.created_at >= start_date and row.created_at <= end_date] - - def query_text(self, query: str) -> List[Record]: - return [row for row in self.rows if row.role not in ["system", "function"] and query.lower() in row.text.lower()] - - def delete(self, filters: Optional[Dict] = {}): - raise NotImplementedError - - def delete_table(self, filters: Optional[Dict] = {}): - if os.path.exists(self.agent_config.save_state_dir()): - shutil.rmtree(self.agent_config.save_state_dir()) - if 
os.path.exists(self.agent_config.save_persistence_manager_dir()): - shutil.rmtree(self.agent_config.save_persistence_manager_dir()) diff --git a/memgpt/constants.py b/memgpt/constants.py index 9710e8d7..a5806432 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -5,6 +5,7 @@ MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt") DEFAULT_MEMGPT_MODEL = "gpt-4" DEFAULT_PERSONA = "sam_pov" DEFAULT_HUMAN = "basic" +DEFAULT_PRESET = "memgpt_chat" FIRST_MESSAGE_ATTEMPTS = 10 diff --git a/memgpt/data_types.py b/memgpt/data_types.py index d2768bd7..821a1fc9 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -1,9 +1,12 @@ """ This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """ import uuid +from datetime import datetime from abc import abstractmethod from typing import Optional, List, Dict import numpy as np +from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS +from memgpt.utils import get_local_time, format_datetime # Defining schema objects: # Note: user/agent can borrow from MemGPTConfig/AgentConfig classes @@ -54,7 +57,7 @@ class Message(Record): role: str, text: str, model: str, # model used to make function call - user: Optional[str] = None, # optional participant name + name: Optional[str] = None, # optional participant name created_at: Optional[str] = None, tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested tool_call_id: Optional[str] = None, @@ -70,7 +73,7 @@ class Message(Record): # openai info self.role = role # role (agent/user/function) - self.user = user + self.name = name # tool (i.e. 
function) call info (optional) @@ -134,15 +137,303 @@ class Passage(Record): # pass -class Source(Record): +class LLMConfig: + def __init__( + self, + model: Optional[str] = "gpt-4", + model_endpoint_type: Optional[str] = "openai", + model_endpoint: Optional[str] = "https://api.openai.com/v1", + model_wrapper: Optional[str] = None, + context_window: Optional[int] = None, + ): + self.model = model + self.model_endpoint_type = model_endpoint_type + self.model_endpoint = model_endpoint + self.model_wrapper = model_wrapper + self.context_window = context_window + + if context_window is None: + self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"] + else: + self.context_window = context_window + + +class OpenAILLMConfig(LLMConfig): + def __init__(self, openai_key, **kwargs): + super().__init__(**kwargs) + self.openai_key = openai_key + + +class AzureLLMConfig(LLMConfig): + def __init__( + self, + azure_key: Optional[str] = None, + azure_endpoint: Optional[str] = None, + azure_version: Optional[str] = None, + azure_deployment: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.azure_key = azure_key + self.azure_endpoint = azure_endpoint + self.azure_version = azure_version + self.azure_deployment = azure_deployment + + +class EmbeddingConfig: + def __init__( + self, + embedding_endpoint_type: Optional[str] = "local", + embedding_endpoint: Optional[str] = None, + embedding_model: Optional[str] = None, + embedding_dim: Optional[int] = 384, + embedding_chunk_size: Optional[int] = 300, + # openai-only + openai_key: Optional[str] = None, + # azure-only + azure_key: Optional[str] = None, + azure_endpoint: Optional[str] = None, + azure_version: Optional[str] = None, + azure_deployment: Optional[str] = None, + ): + self.embedding_endpoint_type = embedding_endpoint_type + self.embedding_endpoint = embedding_endpoint + self.embedding_model = embedding_model + self.embedding_dim = embedding_dim + 
self.embedding_chunk_size = embedding_chunk_size + + # openai + self.openai_key = openai_key + + # azure + self.azure_key = azure_key + self.azure_endpoint = azure_endpoint + self.azure_version = azure_version + self.azure_deployment = azure_deployment + + +class OpenAIEmbeddingConfig(EmbeddingConfig): + def __init__(self, openai_key: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.openai_key = openai_key + + +class AzureEmbeddingConfig(EmbeddingConfig): + def __init__( + self, + azure_key: Optional[str] = None, + azure_endpoint: Optional[str] = None, + azure_version: Optional[str] = None, + azure_deployment: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.azure_key = azure_key + self.azure_endpoint = azure_endpoint + self.azure_version = azure_version + self.azure_deployment = azure_deployment + + +class User: + + """Defines user and default configurations""" + + # TODO: make sure to encrypt/decrypt keys before storing in DB + + def __init__( + self, + id: Optional[uuid.UUID] = None, + default_preset=DEFAULT_PRESET, + default_persona=DEFAULT_PERSONA, + default_human=DEFAULT_HUMAN, + default_agent=None, + default_llm_config: Optional[LLMConfig] = None, # defaults: llm model + default_embedding_config: Optional[EmbeddingConfig] = None, # defaults: embeddings + # azure information + azure_key=None, + azure_endpoint=None, + azure_version=None, + azure_deployment=None, + # openai information + openai_key=None, + # other + policies_accepted=False, + ): + if id is None: + self.id = uuid.uuid4() + else: + self.id = id + + self.default_preset = default_preset + self.default_persona = default_persona + self.default_human = default_human + self.default_agent = default_agent + + # model defaults + self.default_llm_config = default_llm_config if default_llm_config is not None else LLMConfig() + self.default_embedding_config = default_embedding_config if default_embedding_config is not None else EmbeddingConfig() + + # azure 
information + # TODO: split this up accross model config and embedding config? + self.azure_key = azure_key + self.azure_endpoint = azure_endpoint + self.azure_version = azure_version + self.azure_deployment = azure_deployment + + # openai information + self.openai_key = openai_key + + # set default embedding config + if default_embedding_config is None: + if self.openai_key: + self.default_embedding_config = OpenAIEmbeddingConfig( + openai_key=self.openai_key, + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + ) + elif self.azure_key: + self.default_embedding_config = AzureEmbeddingConfig( + azure_key=self.azure_key, + azure_endpoint=self.azure_endpoint, + azure_version=self.azure_version, + azure_deployment=self.azure_deployment, + embedding_endpoint_type="azure", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + ) + else: + # memgpt hosted + self.default_embedding_config = EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_dim=1024, + embedding_chunk_size=300, + ) + + # set default LLM config + if default_llm_config is None: + if self.openai_key: + self.default_llm_config = OpenAILLMConfig( + openai_key=self.openai_key, + model="gpt-4", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model_wrapper=None, + context_window=LLM_MAX_TOKENS["gpt-4"], + ) + elif self.azure_key: + self.default_llm_config = AzureLLMConfig( + azure_key=self.azure_key, + azure_endpoint=self.azure_endpoint, + azure_version=self.azure_version, + azure_deployment=self.azure_deployment, + model="gpt-4", + model_endpoint_type="azure", + model_endpoint="https://api.openai.com/v1", + model_wrapper=None, + context_window=LLM_MAX_TOKENS["gpt-4"], + ) + else: + # memgpt hosted + self.default_llm_config = LLMConfig( + model="ehartford/dolphin-2.5-mixtral-8x7b", + 
model_endpoint_type="vllm", + model_endpoint="https://api.memgpt.ai", + model_wrapper="chatml", + context_window=16384, + ) + + # misc + self.policies_accepted = policies_accepted + + +class AgentState: + def __init__( + self, + name: str, + user_id: str, + persona: str, # the filename where the persona was originally sourced from + human: str, # the filename where the human was originally sourced from + llm_config: LLMConfig, + embedding_config: EmbeddingConfig, + preset: str, + # (in-context) state contains: + # persona: str # the current persona text + # human: str # the current human text + # system: str, # system prompt (not required if initializing with a preset) + # functions: dict, # schema definitions ONLY (function code linked at runtime) + # messages: List[dict], # in-context messages + id: Optional[uuid.UUID] = None, + state: Optional[dict] = None, + created_at: Optional[str] = None, + ): + if id is None: + self.id = uuid.uuid4() + else: + self.id = id + + # TODO(swooders) we need to handle the case where name is None here + # in AgentConfig we autogenerate a name, not sure what the correct thing w/ DBs is, what about NounAdjective combos? Like giphy does? 
BoredGiraffe etc + self.name = name + self.user_id = user_id + self.preset = preset + self.persona = persona + self.human = human + + self.llm_config = llm_config + self.embedding_config = embedding_config + + self.created_at = created_at if created_at is not None else datetime.now() + + # state + self.state = state + + # def __eq__(self, other): + # if not isinstance(other, AgentState): + # # return False + # return NotImplemented + + # return ( + # self.name == other.name + # and self.user_id == other.user_id + # and self.persona == other.persona + # and self.human == other.human + # and vars(self.llm_config) == vars(other.llm_config) + # and vars(self.embedding_config) == vars(other.embedding_config) + # and self.preset == other.preset + # and self.state == other.state + # ) + + # def __dict__(self): + # return { + # "id": self.id, + # "name": self.name, + # "user_id": self.user_id, + # "preset": self.preset, + # "persona": self.persona, + # "human": self.human, + # "llm_config": self.llm_config, + # "embedding_config": self.embedding_config, + # "created_at": format_datetime(self.created_at), + # "state": self.state, + # } + + +class Source: def __init__( self, user_id: str, name: str, created_at: Optional[str] = None, - id: Optional[str] = None, + id: Optional[uuid.UUID] = None, ): - super().__init__(id) + if id is None: + self.id = uuid.uuid4() + else: + self.id = id + self.name = name self.user_id = user_id self.created_at = created_at diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py index 396297a0..299266b4 100644 --- a/memgpt/embeddings.py +++ b/memgpt/embeddings.py @@ -1,8 +1,10 @@ import typer +import uuid from typing import Optional, List import os from memgpt.utils import is_valid_url +from memgpt.data_types import EmbeddingConfig from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding from llama_index.embeddings import TextEmbeddingsInference @@ -134,19 +136,14 @@ class EmbeddingEndpoint(BaseEmbedding): return 
self._get_text_embedding(text) -def embedding_model(): +def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None): """Return LlamaIndex embedding model to use for embeddings""" - from memgpt.config import MemGPTConfig - - # load config - config = MemGPTConfig.load() endpoint_type = config.embedding_endpoint_type if endpoint_type == "openai": - model = OpenAIEmbedding( - api_base=config.embedding_endpoint, api_key=config.openai_key, additional_kwargs={"user": config.anon_clientid} - ) + additional_kwargs = {"user_id": user_id} if user_id else {} + model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key, additional_kwargs=additional_kwargs) return model elif endpoint_type == "azure": # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings @@ -160,7 +157,7 @@ def embedding_model(): api_version=config.azure_version, ) elif endpoint_type == "hugging-face": - embed_model = EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=config.anon_clientid) + embed_model = EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id) return embed_model else: # default to hugging face model running local diff --git a/memgpt/main.py b/memgpt/main.py index 2e0ef096..dd0edca5 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -18,6 +18,7 @@ from prettytable import PrettyTable console = Console() from memgpt.interface import CLIInterface as interface # for printing to terminal +from memgpt.config import MemGPTConfig import memgpt.agent as agent import memgpt.system as system import memgpt.constants as constants @@ -25,7 +26,8 @@ import memgpt.errors as errors from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, suppress_stdout from memgpt.cli.cli_config import configure, list, add, delete from memgpt.cli.cli_load import app as load_app -from memgpt.connectors.storage import StorageConnector, TableType +from 
memgpt.agent_store.storage import StorageConnector, TableType +from memgpt.metadata import MetadataStore app = typer.Typer(pretty_exceptions_enable=False) app.command(name="run")(run) @@ -52,7 +54,7 @@ def clear_line(strip_ui=False): sys.stdout.flush() -def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=False): +def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, no_verify=False, cfg=None, strip_ui=False): counter = 0 user_input = None skip_next_user_input = False @@ -65,7 +67,7 @@ def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=Fals print() multiline_input = False - metadata_db = StorageConnector.get_metadata_storage_connector(table_type=TableType.DATA_SOURCES) # already filters by user + ms = MetadataStore(config) while True: if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): # Ask for user input @@ -104,7 +106,8 @@ def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=Fals elif user_input.lower() == "/attach": # TODO: check if agent already has it - data_source_options = [row.name for row in metadata_db.get_all()] + data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id) + data_source_options = [s.name for s in data_source_options] if len(data_source_options) == 0: typer.secho( 'No sources available. You must load a souce with "memgpt load ..." before running /attach.', @@ -117,16 +120,6 @@ def run_agent_loop(memgpt_agent, first, no_verify=False, cfg=None, strip_ui=Fals # attach new data attach(memgpt_agent.config.name, data_source) - # update agent config - memgpt_agent.config.attach_data_source(data_source) - - # reload agent with new data source - # TODO: maybe make this less ugly... 
- with suppress_stdout(): - memgpt_agent.persistence_manager.archival_memory.storage = StorageConnector.get_archival_storage_connector( - agent_config=memgpt_agent.config - ) - # TODO: update metadata_db to record attached agents continue elif user_input.lower() == "/dump" or user_input.lower().startswith("/dump "): diff --git a/memgpt/memory.py b/memgpt/memory.py index 59febe4a..ce7a8e38 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -291,21 +291,20 @@ class BaseRecallMemory(RecallMemory): """Recall memory based on base functions implemented by storage connectors""" - def __init__(self, agent_config, restrict_search_to_summaries=False): + def __init__(self, agent_state, restrict_search_to_summaries=False): # If true, the pool of messages that can be queried are the automated summaries only # (generated when the conversation window needs to be shortened) self.restrict_search_to_summaries = restrict_search_to_summaries - from memgpt.connectors.storage import StorageConnector + from memgpt.agent_store.storage import StorageConnector - self.agent_config = agent_config - config = MemGPTConfig.load() + self.agent_state = agent_state # create embedding model - self.embed_model = embedding_model() - self.embedding_chunk_size = config.embedding_chunk_size + self.embed_model = embedding_model(agent_state.embedding_config) + self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size # create storage backend - self.storage = StorageConnector.get_recall_storage_connector(agent_config=agent_config) + self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id) # TODO: have some mechanism for cleanup otherwise will lead to OOM self.cache = {} @@ -352,31 +351,30 @@ class BaseRecallMemory(RecallMemory): class EmbeddingArchivalMemory(ArchivalMemory): """Archival memory with embedding based search""" - def __init__(self, agent_config, top_k: Optional[int] = 100): + def __init__(self, agent_state, 
top_k: Optional[int] = 100): """Init function for archival memory :param archival_memory_database: name of dataset to pre-fill archival with :type archival_memory_database: str """ - from memgpt.connectors.storage import StorageConnector + from memgpt.agent_store.storage import StorageConnector self.top_k = top_k - self.agent_config = agent_config - self.config = MemGPTConfig.load() + self.agent_state = agent_state # create embedding model - self.embed_model = embedding_model() - self.embedding_chunk_size = self.config.embedding_chunk_size + self.embed_model = embedding_model(agent_state.embedding_config) + self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size # create storage backend - self.storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config) + self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id) # TODO: have some mechanism for cleanup otherwise will lead to OOM self.cache = {} def create_passage(self, text, embedding): return Passage( - user_id=self.config.anon_clientid, - agent_id=self.agent_config.name, + user_id=self.agent_state.user_id, + agent_id=self.agent_state.id, text=text, embedding=embedding, ) @@ -447,7 +445,7 @@ class EmbeddingArchivalMemory(ArchivalMemory): for passage in list(self.storage.get_all(limit=limit)): # TODO: only get first 10 passages.append(str(passage.text)) memory_str = "\n".join(passages) - return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" + return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" + f"\nSize: {self.storage.size()}" def __len__(self): return self.storage.size() diff --git a/memgpt/metadata.py b/memgpt/metadata.py new file mode 100644 index 00000000..a2807cab --- /dev/null +++ b/memgpt/metadata.py @@ -0,0 +1,364 @@ +""" Metadata store for user/agent/data_source information""" +import os +from typing import Optional, List, Dict +from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, 
DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS +from memgpt.utils import get_local_time +from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig +from memgpt.config import MemGPTConfig + +from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean +from sqlalchemy import func +from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base +from sqlalchemy.orm.session import close_all_sessions +from sqlalchemy.sql import func +from sqlalchemy import Column, BIGINT, String, DateTime +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy_json import mutable_json_type, MutableJson +from sqlalchemy import TypeDecorator, CHAR +import uuid + + +from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base + + +Base = declarative_base() + + +# Custom UUID type +class CommonUUID(TypeDecorator): + impl = CHAR + cache_ok = True + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(UUID(as_uuid=True)) + else: + return dialect.type_descriptor(CHAR()) + + def process_bind_param(self, value, dialect): + if dialect.name == "postgresql" or value is None: + return value + else: + return str(value) # Convert UUID to string for SQLite + + def process_result_value(self, value, dialect): + if dialect.name == "postgresql" or value is None: + return value + else: + return uuid.UUID(value) + + +class LLMConfigColumn(TypeDecorator): + """Custom type for storing LLMConfig as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + return vars(value) + return value + + def process_result_value(self, value, dialect): + if value: + return LLMConfig(**value) + return value + + +class EmbeddingConfigColumn(TypeDecorator): + """Custom type for storing EmbeddingConfig as JSON""" + + impl = 
JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + return vars(value) + return value + + def process_result_value(self, value, dialect): + if value: + return EmbeddingConfig(**value) + return value + + +class UserModel(Base): + __tablename__ = "users" + __table_args__ = {"extend_existing": True} + + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + default_preset = Column(String) + default_persona = Column(String) + default_human = Column(String) + default_agent = Column(String) + + default_llm_config = Column(LLMConfigColumn) + default_embedding_config = Column(EmbeddingConfigColumn) + + azure_key = Column(String, nullable=True) + azure_endpoint = Column(String, nullable=True) + azure_version = Column(String, nullable=True) + azure_deployment = Column(String, nullable=True) + + openai_key = Column(String, nullable=True) + policies_accepted = Column(Boolean, nullable=False, default=False) + + def __repr__(self) -> str: + return f"" + + def to_record(self) -> User: + return User( + id=self.id, + default_preset=self.default_preset, + default_persona=self.default_persona, + default_human=self.default_human, + default_agent=self.default_agent, + default_llm_config=self.default_llm_config, + default_embedding_config=self.default_embedding_config, + azure_key=self.azure_key, + azure_endpoint=self.azure_endpoint, + azure_version=self.azure_version, + azure_deployment=self.azure_deployment, + openai_key=self.openai_key, + policies_accepted=self.policies_accepted, + ) + + +class AgentModel(Base): + """Defines data model for storing Passages (consisting of text, embedding)""" + + __tablename__ = "agents" + __table_args__ = {"extend_existing": True} + + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + user_id = Column(CommonUUID, nullable=False) + name = Column(String, nullable=False) + persona = Column(String) + human = 
Column(String) + preset = Column(String) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + # configs + llm_config = Column(LLMConfigColumn) + embedding_config = Column(EmbeddingConfigColumn) + + # state + state = Column(JSON) + + def __repr__(self) -> str: + return f"" + + def to_record(self) -> AgentState: + return AgentState( + id=self.id, + user_id=self.user_id, + name=self.name, + persona=self.persona, + human=self.human, + preset=self.preset, + created_at=self.created_at, + llm_config=self.llm_config, + embedding_config=self.embedding_config, + state=self.state, + ) + + +class SourceModel(Base): + """Defines data model for storing Passages (consisting of text, embedding)""" + + __tablename__ = "sources" + __table_args__ = {"extend_existing": True} + + # Assuming passage_id is the primary key + # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + user_id = Column(CommonUUID, nullable=False) + name = Column(String, nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + # TODO: add num passages + + def __repr__(self) -> str: + return f"" + + def to_record(self) -> Source: + return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at) + + +class AgentSourceMappingModel(Base): + + """Stores mapping between agent -> source""" + + __tablename__ = "agent_source_mapping" + + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + user_id = Column(CommonUUID, nullable=False) + agent_id = Column(CommonUUID, nullable=False) + source_id = Column(CommonUUID, nullable=False) + + def __repr__(self) -> str: + return f"" + + +class MetadataStore: + def __init__(self, config: MemGPTConfig): + # TODO: get DB URI or path + if config.metadata_storage_type == "postgres": + self.uri = config.metadata_storage_uri + elif config.metadata_storage_type == "sqlite": + path = 
os.path.join(config.metadata_storage_path, "sqlite.db") + self.uri = f"sqlite:///{path}" + else: + raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}") + + # TODO: check to see if table(s) need to be greated or not + + self.engine = create_engine(self.uri) + Base.metadata.create_all( + self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__] + ) + self.Session = sessionmaker(bind=self.engine) + + def create_agent(self, agent: AgentState): + # insert into agent table + session = self.Session() + # make sure agent.name does not already exist for user user_id + if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0: + raise ValueError(f"Agent with name {agent.name} already exists") + session.add(AgentModel(**vars(agent))) + session.commit() + + def create_source(self, source: Source): + session = self.Session() + # make sure source.name does not already exist for user + if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0: + raise ValueError(f"Source with name {source.name} already exists") + session.add(SourceModel(**vars(source))) + session.commit() + + def create_user(self, user: User): + session = self.Session() + if session.query(UserModel).filter(UserModel.id == user.id).count() > 0: + raise ValueError(f"User with id {user.id} already exists") + session.add(UserModel(**vars(user))) + session.commit() + + def update_agent(self, agent: AgentState): + session = self.Session() + session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent)) + session.commit() + + def update_user(self, user: User): + session = self.Session() + session.query(UserModel).filter(UserModel.id == user.id).update(vars(user)) + session.commit() + + def update_source(self, source: Source): + session = self.Session() + 
session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source)) + session.commit() + + def delete_agent(self, agent_id: str): + session = self.Session() + session.query(AgentModel).filter(AgentModel.id == agent_id).delete() + session.commit() + + def delete_source(self, source_id: str): + session = self.Session() + + # delete from sources table + session.query(SourceModel).filter(SourceModel.id == source_id).delete() + + # delete any mappings + session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete() + + session.commit() + + def delete_user(self, user_id: str): + session = self.Session() + + # delete from users table + session.query(UserModel).filter(UserModel.id == user_id).delete() + + # delete associated agents + session.query(AgentModel).filter(AgentModel.user_id == user_id).delete() + + # delete associated sources + session.query(SourceModel).filter(SourceModel.user_id == user_id).delete() + + # delete associated mappings + session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete() + + session.commit() + + def list_agents(self, user_id: str) -> List[AgentState]: + session = self.Session() + results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() + return [r.to_record() for r in results] + + def list_sources(self, user_id: str) -> List[Source]: + session = self.Session() + results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all() + return [r.to_record() for r in results] + + def get_agent(self, agent_id: str = None, agent_name: str = None, user_id: str = None) -> Optional[AgentState]: + session = self.Session() + if agent_id: + results = session.query(AgentModel).filter(AgentModel.id == agent_id).all() + else: + assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name" + results = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id 
== user_id).all() + + if len(results) == 0: + return None + assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result + return results[0].to_record() + + def get_user(self, user_id: str) -> Optional[User]: + session = self.Session() + results = session.query(UserModel).filter(UserModel.id == user_id).all() + if len(results) == 0: + return None + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + return results[0].to_record() + + def get_source(self, source_id: str = None, user_id: str = None, source_name: str = None) -> Optional[Source]: + session = self.Session() + if source_id: + results = session.query(SourceModel).filter(SourceModel.id == source_id).all() + else: + assert user_id is not None and source_name is not None + results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all() + if len(results) == 0: + return None + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + return results[0].to_record() + + # agent source metadata + def attach_source(self, user_id: str, agent_id: str, source_id: str): + session = self.Session() + session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id)) + session.commit() + + def list_attached_sources(self, agent_id: str) -> List[Column]: + session = self.Session() + results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all() + return [r.source_id for r in results] + + def list_attached_agents(self, source_id): + session = self.Session() + results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all() + return [r.agent_id for r in results] + + def detach_source(self, agent_id: str, source_id: str): + session = self.Session() + session.query(AgentSourceMappingModel).filter( + AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id + 
).delete() + session.commit() diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py index 8e68449e..a6a6fefb 100644 --- a/memgpt/openai_tools.py +++ b/memgpt/openai_tools.py @@ -378,23 +378,25 @@ def create( config = MemGPTConfig.load() # load credentials (currently not stored in agent config) - printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}") - if agent_config.model_endpoint_type == "openai": + printd(f"Using model {agent_config.llm_config.model_endpoint_type}, endpoint: {agent_config.llm_config.model_endpoint}") + if agent_config.llm_config.model_endpoint_type == "openai": # openai return openai_chat_completions_request( - url=agent_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + url=agent_config.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=config.openai_key, # 'sk....' data=dict( - model=agent_config.model, + model=agent_config.llm_config.model, messages=messages, functions=functions, function_call=function_call, user=config.anon_clientid, ), ) - elif agent_config.model_endpoint_type == "azure": + elif agent_config.llm_config.model_endpoint_type == "azure": # azure - azure_deployment = config.azure_deployment if config.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[agent_config.model] + azure_deployment = ( + config.azure_deployment if config.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[agent_config.llm_config.model] + ) return azure_openai_chat_completions_request( resource_name=config.azure_endpoint, deployment_id=azure_deployment, @@ -411,14 +413,14 @@ def create( ) else: # local model return get_chat_completion( - model=agent_config.model, + model=agent_config.llm_config.model, messages=messages, functions=functions, function_call=function_call, - context_window=agent_config.context_window, - endpoint=agent_config.model_endpoint, - 
endpoint_type=agent_config.model_endpoint_type, - wrapper=agent_config.model_wrapper, + context_window=agent_config.llm_config.context_window, + endpoint=agent_config.llm_config.model_endpoint, + endpoint_type=agent_config.llm_config.model_endpoint_type, + wrapper=agent_config.llm_config.model_wrapper, user=config.anon_clientid, # hint first_message=first_message, diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index ad7621cc..b9980fd4 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -7,8 +7,7 @@ from memgpt.memory import ( EmbeddingArchivalMemory, ) from memgpt.utils import get_local_time, printd -from memgpt.data_types import Message, ToolCall -from memgpt.config import MemGPTConfig +from memgpt.data_types import Message, ToolCall, AgentState from datetime import datetime @@ -46,15 +45,14 @@ class LocalStateManager(PersistenceManager): recall_memory_cls = BaseRecallMemory archival_memory_cls = EmbeddingArchivalMemory - def __init__(self, agent_config: AgentConfig): + def __init__(self, agent_state: AgentState): # Memory held in-state useful for debugging stateful versions self.memory = None self.messages = [] # current in-context messages # self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB) - self.archival_memory = EmbeddingArchivalMemory(agent_config) - self.recall_memory = BaseRecallMemory(agent_config) - self.agent_config = agent_config - self.config = MemGPTConfig.load() + self.archival_memory = EmbeddingArchivalMemory(agent_state) + self.recall_memory = BaseRecallMemory(agent_state) + self.agent_state = agent_state @classmethod def load(cls, agent_config: AgentConfig): @@ -133,11 +131,12 @@ class LocalStateManager(PersistenceManager): tool_calls = None return Message( - user_id=self.config.anon_clientid, - agent_id=self.agent_config.name, + user_id=self.agent_state.user_id, + agent_id=self.agent_state.id, role=message["role"], 
text=message["content"], - model=self.agent_config.model, + name=message["name"] if "name" in message else None, + model=self.agent_state.llm_config.model, created_at=parse_formatted_time(timestamp), tool_calls=tool_calls, tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None, diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index d688f0fb..ab698807 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -1,16 +1,33 @@ -from .utils import load_all_presets, is_valid_yaml_format -from ..prompts import gpt_functions -from ..prompts import gpt_system -from ..functions.functions import load_all_function_sets +from memgpt.data_types import AgentState +from memgpt.interface import AgentInterface +from memgpt.presets.utils import load_all_presets, is_valid_yaml_format +from memgpt.utils import get_human_text, get_persona_text +from memgpt.prompts import gpt_system +from memgpt.functions.functions import load_all_function_sets -DEFAULT_PRESET = "memgpt_chat" available_presets = load_all_presets() preset_options = list(available_presets.keys()) -def use_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): - """Storing combinations of SYSTEM + FUNCTION prompts""" +# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): +def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface): + """Initialize a new agent from a preset (combination of system + function)""" + + # Input validation + if agent_state.persona is None: + raise ValueError(f"'persona' not specified in AgentState (required)") + if agent_state.human is None: + raise ValueError(f"'human' not specified in AgentState (required)") + if agent_state.preset is None: + raise ValueError(f"'preset' not specified in AgentState (required)") + if not (agent_state.state == {} or agent_state.state is None): + raise ValueError(f"'state' must be uninitialized 
(empty)") + + preset_name = agent_state.preset + persona_file = agent_state.persona + human_file = agent_state.human + model = agent_state.llm_config.model from memgpt.agent import Agent from memgpt.utils import printd @@ -40,22 +57,26 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}") preset_function_set[f_name] = available_functions[f_name] assert len(preset_function_set_names) == len(preset_function_set) + preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()] printd(f"Available functions:\n", list(preset_function_set.keys())) - # preset_function_set = {f_name: f_dict for f_name, f_dict in available_functions.items() if f_name in preset_function_set_names} - # printd(f"Available functions:\n", [f_name for f_name, f_dict in preset_function_set.items()]) - # Make sure that every function the preset wanted is inside the available functions - # assert len(preset_function_set_names) == len(preset_function_set) + # Override the following in the AgentState: + # persona: str # the current persona text + # human: str # the current human text + # system: str, # system prompt (not required if initializing with a preset) + # functions: dict, # schema definitions ONLY (function code linked at runtime) + # messages: List[dict], # in-context messages + agent_state.state = { + "persona": get_persona_text(persona_file), + "human": get_human_text(human_file), + "system": gpt_system.get_system_text(preset_system_prompt), + "functions": preset_function_set_schemas, + "messages": None, + } return Agent( - config=agent_config, - model=model, - system=gpt_system.get_system_text(preset_system_prompt), - functions=preset_function_set, + agent_state=agent_state, interface=interface, - persistence_manager=persistence_manager, - persona_notes=persona, - human_notes=human, # 
gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, ) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 821babb5..8fcdc6b5 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -12,11 +12,14 @@ from memgpt.agent import Agent import memgpt.system as system import memgpt.constants as constants from memgpt.cli.cli import attach -from memgpt.connectors.storage import StorageConnector + +# from memgpt.agent_store.storage import StorageConnector +from memgpt.metadata import MetadataStore import memgpt.presets.presets as presets import memgpt.utils as utils import memgpt.server.utils as server_utils from memgpt.persistence_manager import PersistenceManager, LocalStateManager +from memgpt.data_types import Source, Passage, Document, User, AgentState # TODO use custom interface from memgpt.interface import CLIInterface # for printing to terminal @@ -130,7 +133,7 @@ class SyncServer(LockingServer): max_chaining_steps: bool = None, # default_interface_cls: AgentInterface = CLIInterface, default_interface: AgentInterface = CLIInterface(), - default_persistence_manager_cls: PersistenceManager = LocalStateManager, + # default_persistence_manager_cls: PersistenceManager = LocalStateManager, ): """Server process holds in-memory agents that are being run""" @@ -149,15 +152,20 @@ class SyncServer(LockingServer): self.default_interface = default_interface # The default persistence manager that will get assigned to agents ON CREATION - self.default_persistence_manager_cls = default_persistence_manager_cls + # self.default_persistence_manager_cls = default_persistence_manager_cls + + # Initialize the connection to the DB + self.config = MemGPTConfig() + self.ms = MetadataStore(self.config) def save_agents(self): + """Saves all the agents that are in the in-memory object store""" for agent_d in self.active_agents: try: agent_d["agent"].save() 
logger.info(f"Saved agent {agent_d['agent_id']}") except Exception as e: - logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}") + logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}") def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]: """Get the agent object from the in-memory object store""" @@ -187,18 +195,20 @@ class SyncServer(LockingServer): if interface is None: interface = self.default_interface - # If the agent isn't load it, load it and put it into memory - if AgentConfig.exists(agent_id): - logger.debug(f"(user={user_id}, agent={agent_id}) exists, loading into memory...") - agent_config = AgentConfig.load(agent_id) - with utils.suppress_stdout(): - memgpt_agent = Agent.load_agent(interface=interface, agent_config=agent_config) + try: + agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id) + if not agent_state: + raise ValueError(f"agent_id {agent_id} does not exist") + + # Instantiate an agent object using the state retrieved + memgpt_agent = Agent(agent_state=agent_state, interface=interface) + + # Add the agent to the in-memory store and return its reference self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent) return memgpt_agent - # If the agent doesn't exist, throw an error - else: - raise ValueError(f"agent_id {agent_id} does not exist") + except Exception as e: + logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}") def _get_or_load_agent(self, user_id: str, agent_id: str) -> Agent: """Check if the agent is in-memory, then load""" @@ -408,42 +418,72 @@ class SyncServer(LockingServer): def create_agent( self, user_id: str, - agent_config: Union[dict, AgentConfig], + agent_config: dict, interface: Union[AgentInterface, None] = None, - persistence_manager: Union[PersistenceManager, None] = None, - ) -> str: + # persistence_manager: Union[PersistenceManager, None] = None, + ) -> AgentState: """Create a 
new agent using a config""" # Initialize the agent based on the provided configuration - if isinstance(agent_config, dict): - agent_config = AgentConfig(**agent_config) + if not isinstance(agent_config, dict): + raise ValueError(f"agent_config must be provided as a dictionary") if interface is None: # interface = self.default_interface_cls() interface = self.default_interface - if persistence_manager is None: - persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config) + # if persistence_manager is None: + # persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config) - # Create agent via preset from config - agent = presets.use_preset( - agent_config.preset, - agent_config, - agent_config.model, - utils.get_persona_text(agent_config.persona), - utils.get_human_text(agent_config.human), - interface, - persistence_manager, + # TODO actually use the user_id that was passed into the server + USER_ID = self.config.anon_clientid + # create user and agent + user = User(id=USER_ID) + user = self.ms.get_user(user_id=USER_ID) + if not user: + user = User(id=USER_ID) + self.ms.create_user(user) + + agent_state = AgentState( + user_id=user.id, + name=agent_config["name"] if "name" in agent_config else utils.create_random_username(), + preset=agent_config["preset"] if "preset" in agent_config else user.default_preset, + # TODO we need to allow passing raw persona/human text via the server request + persona=agent_config["persona"] if "persona" in agent_config else user.default_persona, + human=agent_config["human"] if "human" in agent_config else user.default_human, + llm_config=agent_config["llm_config"] if "llm_config" in agent_config else user.default_llm_config, + embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else user.default_embedding_config, ) + agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface) + # TODO where should we handle saving of the 
AgentState? agent.save() + # try: + # self.ms.create_agent(agent) + # except ValueError: + # agent name under user.id already exists, not OK + # raise logger.info(f"Created new agent from config: {agent}") - return agent.config.name + return agent.config + + def delete_agent( + self, + user_id: str, + agent_id: str, + ): + # Make sure the user owns the agent + # TODO use real user_id + USER_ID = self.config.anon_clientid + agent = self.ms.get_agent(agent_id=agent_id, user_id=USER_ID) + if agent is not None: + self.ms.delete_agent(agent_id=agent_id) def list_agents(self, user_id: str) -> dict: """List all available agents to a user""" - agents_list = utils.list_agent_config_files() - return {"num_agents": len(agents_list), "agent_names": agents_list} + # TODO actually use the user_id that was passed into the server + USER_ID = self.config.anon_clientid + agents_list = self.ms.list_agents(user_id=USER_ID) + return {"num_agents": len(agents_list), "agent_names": [state.name for state in agents_list]} def get_agent_memory(self, user_id: str, agent_id: str) -> dict: """Return the memory of an agent (core memory + non-core statistics)""" diff --git a/memgpt/utils.py b/memgpt/utils.py index 8cb60c0a..0046df6b 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -4,6 +4,7 @@ import json import os import pickle import platform +import random import subprocess import sys import io @@ -29,6 +30,479 @@ from memgpt.openai_backcompat.openai_object import OpenAIObject # DEBUG = True DEBUG = False +ADJECTIVE_BANK = [ + "beautiful", + "gentle", + "angry", + "vivacious", + "grumpy", + "luxurious", + "fierce", + "delicate", + "fluffy", + "radiant", + "elated", + "magnificent", + "sassy", + "ecstatic", + "lustrous", + "gleaming", + "sorrowful", + "majestic", + "proud", + "dynamic", + "energetic", + "mysterious", + "loyal", + "brave", + "decisive", + "frosty", + "cheerful", + "adorable", + "melancholy", + "vibrant", + "elegant", + "gracious", + "inquisitive", + "opulent", + 
"peaceful", + "rebellious", + "scintillating", + "dazzling", + "whimsical", + "impeccable", + "meticulous", + "resilient", + "charming", + "vivacious", + "creative", + "intuitive", + "compassionate", + "innovative", + "enthusiastic", + "tremendous", + "effervescent", + "tenacious", + "fearless", + "sophisticated", + "witty", + "optimistic", + "exquisite", + "sincere", + "generous", + "kindhearted", + "serene", + "amiable", + "adventurous", + "bountiful", + "courageous", + "diligent", + "exotic", + "grateful", + "harmonious", + "imaginative", + "jubilant", + "keen", + "luminous", + "nurturing", + "outgoing", + "passionate", + "quaint", + "resourceful", + "sturdy", + "tactful", + "unassuming", + "versatile", + "wondrous", + "youthful", + "zealous", + "ardent", + "benevolent", + "capricious", + "dedicated", + "empathetic", + "fabulous", + "gregarious", + "humble", + "intriguing", + "jovial", + "kind", + "lovable", + "mindful", + "noble", + "original", + "pleasant", + "quixotic", + "reliable", + "spirited", + "tranquil", + "unique", + "venerable", + "warmhearted", + "xenodochial", + "yearning", + "zesty", + "amusing", + "blissful", + "calm", + "daring", + "enthusiastic", + "faithful", + "graceful", + "honest", + "incredible", + "joyful", + "kind", + "lovely", + "merry", + "noble", + "optimistic", + "peaceful", + "quirky", + "respectful", + "sweet", + "trustworthy", + "understanding", + "vibrant", + "witty", + "xenial", + "youthful", + "zealous", + "ambitious", + "brilliant", + "careful", + "devoted", + "energetic", + "friendly", + "glorious", + "humorous", + "intelligent", + "jovial", + "knowledgeable", + "loyal", + "modest", + "nice", + "obedient", + "patient", + "quiet", + "resilient", + "selfless", + "tolerant", + "unique", + "versatile", + "warm", + "xerothermic", + "yielding", + "zestful", + "amazing", + "bold", + "charming", + "determined", + "exciting", + "funny", + "happy", + "imaginative", + "jolly", + "keen", + "loving", + "magnificent", + "nifty", + 
"outstanding", + "polite", + "quick", + "reliable", + "sincere", + "thoughtful", + "unusual", + "valuable", + "wonderful", + "xenodochial", + "zealful", + "admirable", + "bright", + "clever", + "dedicated", + "extraordinary", + "generous", + "hardworking", + "inspiring", + "jubilant", + "kind-hearted", + "lively", + "miraculous", + "neat", + "open-minded", + "passionate", + "remarkable", + "stunning", + "truthful", + "upbeat", + "vivacious", + "welcoming", + "yare", + "zealous", +] + +NOUN_BANK = [ + "lizard", + "firefighter", + "banana", + "castle", + "dolphin", + "elephant", + "forest", + "giraffe", + "harbor", + "iceberg", + "jewelry", + "kangaroo", + "library", + "mountain", + "notebook", + "orchard", + "penguin", + "quilt", + "rainbow", + "squirrel", + "teapot", + "umbrella", + "volcano", + "waterfall", + "xylophone", + "yacht", + "zebra", + "apple", + "butterfly", + "caterpillar", + "dragonfly", + "elephant", + "flamingo", + "gorilla", + "hippopotamus", + "iguana", + "jellyfish", + "koala", + "lemur", + "mongoose", + "nighthawk", + "octopus", + "panda", + "quokka", + "rhinoceros", + "salamander", + "tortoise", + "unicorn", + "vulture", + "walrus", + "xenopus", + "yak", + "zebu", + "asteroid", + "balloon", + "compass", + "dinosaur", + "eagle", + "firefly", + "galaxy", + "hedgehog", + "island", + "jaguar", + "kettle", + "lion", + "mammoth", + "nucleus", + "owl", + "pumpkin", + "quasar", + "reindeer", + "snail", + "tiger", + "universe", + "vampire", + "wombat", + "xerus", + "yellowhammer", + "zeppelin", + "alligator", + "buffalo", + "cactus", + "donkey", + "emerald", + "falcon", + "gazelle", + "hamster", + "icicle", + "jackal", + "kitten", + "leopard", + "mushroom", + "narwhal", + "opossum", + "peacock", + "quail", + "rabbit", + "scorpion", + "toucan", + "urchin", + "viper", + "wolf", + "xray", + "yucca", + "zebu", + "acorn", + "biscuit", + "cupcake", + "daisy", + "eyeglasses", + "frisbee", + "goblin", + "hamburger", + "icicle", + "jackfruit", + "kaleidoscope", 
+ "lighthouse", + "marshmallow", + "nectarine", + "obelisk", + "pancake", + "quicksand", + "raspberry", + "spinach", + "truffle", + "umbrella", + "volleyball", + "walnut", + "xylophonist", + "yogurt", + "zucchini", + "asterisk", + "blackberry", + "chimpanzee", + "dumpling", + "espresso", + "fireplace", + "gnome", + "hedgehog", + "illustration", + "jackhammer", + "kumquat", + "lemongrass", + "mandolin", + "nugget", + "ostrich", + "parakeet", + "quiche", + "racquet", + "seashell", + "tadpole", + "unicorn", + "vaccination", + "wolverine", + "xenophobia", + "yam", + "zeppelin", + "accordion", + "broccoli", + "carousel", + "daffodil", + "eggplant", + "flamingo", + "grapefruit", + "harpsichord", + "impression", + "jackrabbit", + "kitten", + "llama", + "mandarin", + "nachos", + "obelisk", + "papaya", + "quokka", + "rooster", + "sunflower", + "turnip", + "ukulele", + "viper", + "waffle", + "xylograph", + "yeti", + "zephyr", + "abacus", + "blueberry", + "crocodile", + "dandelion", + "echidna", + "fig", + "giraffe", + "hamster", + "iguana", + "jackal", + "kiwi", + "lobster", + "marmot", + "noodle", + "octopus", + "platypus", + "quail", + "raccoon", + "starfish", + "tulip", + "urchin", + "vampire", + "walrus", + "xylophone", + "yak", + "zebra", +] + + +def create_random_username() -> str: + """Generate a random username by combining an adjective and a noun.""" + adjective = random.choice(ADJECTIVE_BANK).capitalize() + noun = random.choice(NOUN_BANK).capitalize() + return adjective + noun + + +def verify_first_message_correctness(response, require_send_message=True, require_monologue=False) -> bool: + """Can be used to enforce that the first message always uses send_message""" + response_message = response.choices[0].message + + # First message should be a call to send_message with a non-empty content + if require_send_message and not response_message.get("function_call"): + printd(f"First message didn't include function call: {response_message}") + return False + + 
function_call = response_message.get("function_call") + function_name = function_call.get("name") if function_call is not None else "" + if require_send_message and function_name != "send_message" and function_name != "archival_memory_search": + printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}") + return False + + if require_monologue and ( + not response_message.get("content") or response_message["content"] is None or response_message["content"] == "" + ): + printd(f"First message missing internal monologue: {response_message}") + return False + + if response_message.get("content"): + ### Extras + monologue = response_message.get("content") + + def contains_special_characters(s): + special_characters = '(){}[]"' + return any(char in s for char in special_characters) + + if contains_special_characters(monologue): + printd(f"First message internal monologue contained special characters: {response_message}") + return False + # if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower(): + if "functions" in monologue or "send_message" in monologue: + # Sometimes the syntax won't be correct and internal syntax will leak into message.context + printd(f"First message internal monologue contained reserved words: {response_message}") + return False + + return True + def is_valid_url(url): try: @@ -157,6 +631,10 @@ def get_local_time(timezone=None): return time_str.strip() +def format_datetime(dt): + return dt.strftime("%Y-%m-%d %I:%M:%S %p %Z%z") + + def parse_json(string): """Parse JSON string into JSON with both json and demjson""" result = None diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 8e30761b..57f0605b 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -16,7 +16,7 @@ def create_test_agent(): global client client = MemGPT(quickstart="openai") - agent_id = client.create_agent( + agent_state = client.create_agent( 
agent_config={ # "name": test_agent_id, "persona": constants.DEFAULT_PERSONA, @@ -25,13 +25,14 @@ def create_test_agent(): ) global agent_obj - agent_obj = client.server._get_or_load_agent(user_id="NULL", agent_id=agent_id) + agent_obj = client.server._get_or_load_agent(user_id="NULL", agent_id=agent_state.id) def test_archival(): global agent_obj if agent_obj is None: create_test_agent() + assert agent_obj is not None base_functions.archival_memory_insert(agent_obj, "banana") diff --git a/tests/test_client.py b/tests/test_client.py index 5ef097dc..28416b3a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,34 +1,82 @@ from memgpt import MemGPT from memgpt import constants +from memgpt.data_types import LLMConfig, EmbeddingConfig from .utils import wipe_config -test_agent_id = "test_client_agent" +test_agent_name = "test_client_agent" +test_agent_state = None client = None +test_agent_state_post_message = None + def test_create_agent(): wipe_config() global client client = MemGPT(quickstart="openai") - agent_id = client.create_agent( + global test_agent_state + test_agent_state = client.create_agent( agent_config={ - "name": test_agent_id, + "name": test_agent_name, "persona": constants.DEFAULT_PERSONA, "human": constants.DEFAULT_HUMAN, } ) - assert agent_id is not None - return client, agent_id + assert test_agent_state is not None def test_user_message(): + """Test that we can send a message through the client""" assert client is not None, "Run create_agent test first" - response = client.user_message(agent_id=test_agent_id, message="Hello my name is Test, Client Test") + response = client.user_message(agent_id=test_agent_state.id, message="Hello my name is Test, Client Test") assert response is not None and len(response) > 0 + global test_agent_state_post_message + test_agent_state_post_message = client.server.active_agents[0]["agent"].to_agent_state() + + +def test_save_load(): + """Test that state is being persisted correctly after an /exit + + 
Create a new agent, and request a message + + Then trigger + """ + assert client is not None, "Run create_agent test first" + assert test_agent_state is not None, "Run create_agent test first" + assert test_agent_state_post_message is not None, "Run test_user_message test first" + + # Create a new client (not thread safe), and load the same agent + # The agent state inside should correspond to the initial state pre-message + client2 = MemGPT(quickstart="openai") + client2_agent_obj = client2.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id) + client2_agent_state = client2_agent_obj.to_agent_state() + + # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}" + def check_state_equivalence(state_1, state_2): + assert state_1.keys() == state_2.keys(), f"{state_1.keys()}\n{state_2.keys}" + for k, v1 in state_1.items(): + v2 = state_2[k] + if isinstance(v1, LLMConfig) or isinstance(v1, EmbeddingConfig): + assert vars(v1) == vars(v2), f"{vars(v1)}\n{vars(v2)}" + else: + assert v1 == v2, f"{v1}\n{v2}" + + check_state_equivalence(vars(test_agent_state), vars(client2_agent_state)) + + # Now, write out the save from the original client + # This should persist the test message into the agent state + client.save() + + client3 = MemGPT(quickstart="openai") + client3_agent_obj = client3.server._get_or_load_agent(user_id="", agent_id=test_agent_state.id) + client3_agent_state = client3_agent_obj.to_agent_state() + + check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state)) + if __name__ == "__main__": test_create_agent() diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 48f3f1e4..52df3ffa 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -1,17 +1,16 @@ -# import tempfile -# import asyncio import os - +import uuid import pytest from sqlalchemy.ext.declarative import declarative_base # import memgpt -from memgpt.connectors.storage 
import StorageConnector, TableType +from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.cli.cli_load import load_directory, load_database, load_webpage from memgpt.cli.cli import attach -from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN -from memgpt.config import AgentConfig, MemGPTConfig +from memgpt.config import MemGPTConfig +from memgpt.metadata import MetadataStore +from memgpt.data_types import User, AgentState, EmbeddingConfig @pytest.fixture(autouse=True) @@ -61,10 +60,42 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented") config.save() + # create metadata store + ms = MetadataStore(config) + + # embedding config + if os.getenv("OPENAI_API_KEY"): + embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + openai_key=os.getenv("OPENAI_API_KEY"), + ) + else: + embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384) + + # create user and agent + user = User(id=uuid.UUID(config.anon_clientid), default_embedding_config=embedding_config) + agent = AgentState( + user_id=user.id, + name="test_agent", + preset=user.default_preset, + persona=user.default_persona, + human=user.default_human, + llm_config=user.default_llm_config, + embedding_config=user.default_embedding_config, + ) + ms.delete_user(user.id) + ms.create_user(user) + ms.create_agent(agent) + user = ms.get_user(user.id) + print("Got user:", user, user.default_embedding_config) + # setup storage connectors print("Creating storage connectors...") - data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES) - passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, 
storage_type=passage_storage_connector) + user_id = user.id + print("User ID", user_id) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) # load data name = "test_dataset" @@ -74,23 +105,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c # clear out data print("Resetting tables with delete_table...") - data_source_conn.delete_table() passages_conn.delete_table() print("Re-creating tables...") - data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES) - passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector) - assert ( - data_source_conn.size() == 0 - ), f"Expected 0 records, got {data_source_conn.size()}: {[vars(r) for r in data_source_conn.get_all()]}" + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}" # test: load directory print("Loading directory") - load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False) # cache_dir, + load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False, user_id=user_id) # cache_dir, # test to see if contained in storage print("Querying table...") - sources = data_source_conn.get_all({"name": name}) + sources = ms.list_sources(user_id=user_id) assert len(sources) == 1, f"Expected 1 source, but got {len(sources)}" assert sources[0].name == name, f"Expected name {name}, but got {sources[0].name}" print("Source", sources) @@ -109,40 +135,33 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c # test: listing sources print("Querying all...") - sources = data_source_conn.get_all() + sources = ms.list_sources(user_id=user_id) print("All sources", [s.name for s in sources]) # test 
loading into an agent # create agent - agent_config = AgentConfig( - name="memgpt_test_agent", - persona=DEFAULT_PERSONA, - human=DEFAULT_HUMAN, - model=DEFAULT_MEMGPT_MODEL, - ) - agent_config.save() + agent_id = agent.id # create storage connector print("Creating agent archival storage connector...") - conn = StorageConnector.get_storage_connector( - storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config - ) + conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id) print("Deleting agent archival table...") conn.delete_table() - conn = StorageConnector.get_storage_connector( - storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config - ) + conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id) assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}" # attach data print("Attaching data...") - attach(agent=agent_config.name, data_source=name) + attach(agent=agent.name, data_source=name, user_id=user_id) # test to see if contained in storage assert len(passages) == conn.size() assert len(passages) == len(conn.get_all({"data_source": name})) # test: delete source - data_source_conn.delete({"name": name}) passages_conn.delete({"data_source": name}) - assert len(data_source_conn.get_all({"name": name})) == 0 assert len(passages_conn.get_all({"data_source": name})) == 0 + + # cleanup + ms.delete_user(user.id) + ms.delete_agent(agent.id) + ms.delete_source(sources[0].id) diff --git a/tests/test_metadata_store.py b/tests/test_metadata_store.py new file mode 100644 index 00000000..7ac32c43 --- /dev/null +++ b/tests/test_metadata_store.py @@ -0,0 +1,82 @@ +import os +import pytest + +from memgpt.metadata import MetadataStore +from memgpt.config import MemGPTConfig +from memgpt.data_types import User, 
AgentState, Source, LLMConfig, EmbeddingConfig + + +# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"]) +@pytest.mark.parametrize("storage_connector", ["sqlite"]) +def test_storage(storage_connector): + config = MemGPTConfig() + if storage_connector == "postgres": + if not os.getenv("PGVECTOR_TEST_DB_URL"): + print("Skipping test, missing PG URI") + return + config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") + config.archival_storage_type = "postgres" + config.recall_storage_type = "postgres" + if storage_connector == "sqlite": + config.recall_storage_type = "local" + + ms = MetadataStore(config) + + # generate data + user_1 = User(default_llm_config=LLMConfig(model="gpt-4")) + user_2 = User() + agent_1 = AgentState( + user_id=user_1.id, + name="agent_1", + preset=user_1.default_preset, + persona=user_1.default_persona, + human=user_1.default_human, + llm_config=user_1.default_llm_config, + embedding_config=user_1.default_embedding_config, + ) + source_1 = Source(user_id=user_1.id, name="source_1") + + # test creation + ms.create_user(user_1) + ms.create_user(user_2) + ms.create_agent(agent_1) + ms.create_source(source_1) + + # test listing + len(ms.list_agents(user_id=user_1.id)) == 1 + len(ms.list_agents(user_id=user_2.id)) == 0 + len(ms.list_sources(user_id=user_1.id)) == 1 + len(ms.list_sources(user_id=user_2.id)) == 0 + + # test: updating + + # test: update JSON-stored LLMConfig class + print(agent_1.llm_config, user_1.default_llm_config) + llm_config = ms.get_agent(agent_1.id).llm_config + assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}" + assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}" + llm_config.model = "gpt3.5-turbo" + agent_1.llm_config = llm_config + ms.update_agent(agent_1) + assert ms.get_agent(agent_1.id).llm_config.model == "gpt3.5-turbo", f"Updated LLMConfig to 
{ms.get_agent(agent_1.id).llm_config.model}" + + # test attaching sources + len(ms.list_attached_sources(agent_id=agent_1.id)) == 0 + ms.attach_source(user_1.id, agent_1.id, source_1.id) + len(ms.list_attached_sources(agent_id=agent_1.id)) == 1 + + # test: detaching sources + ms.detach_source(agent_1.id, source_1.id) + len(ms.list_attached_sources(agent_id=agent_1.id)) == 0 + + # test getting + ms.get_user(user_1.id) + ms.get_agent(agent_1.id) + ms.get_source(source_1.id) + + # text deletion + ms.delete_user(user_1.id) + ms.delete_user(user_2.id) + ms.delete_agent(agent_1.id) + ms.delete_source(source_1.id) diff --git a/tests/test_server.py b/tests/test_server.py index 66732779..e97f21b3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,41 +2,45 @@ import memgpt.utils as utils utils.DEBUG = True from memgpt.server.server import SyncServer +from .utils import wipe_config, wipe_memgpt_home def test_server(): + wipe_memgpt_home() + user_id = "NULL" - agent_id = "agent_26" server = SyncServer() try: - server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") - except ValueError as e: + server.user_message(user_id=user_id, agent_id="agent no exist", message="Hello?") + raise Exception("user_message call should have failed") + except (KeyError, ValueError) as e: + # Error is expected print(e) except: raise + agent_state = server.create_agent( + user_id=user_id, + agent_config=dict( + preset="memgpt_chat", + human="cs_phd", + persona="sam_pov", + ), + ) + print(f"Created agent\n{agent_state}") + try: - server.user_message(user_id=user_id, agent_id=agent_id, message="/memory") + server.user_message(user_id=user_id, agent_id=agent_state.id, message="/memory") + raise Exception("user_message call should have failed") except ValueError as e: + # Error is expected print(e) except: raise - try: - print(server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")) - except ValueError as e: - print(e) - except: - raise - - try: - 
server.user_message(user_id=user_id, agent_id="agent no-exist", message="Hello?") - except ValueError as e: - print(e) - except: - raise + print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory")) if __name__ == "__main__": diff --git a/tests/test_storage.py b/tests/test_storage.py index 0aaae5af..3fab73e2 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,33 +1,29 @@ import os +from sqlalchemy.ext.declarative import declarative_base import uuid -import subprocess -import sys import pytest -# subprocess.check_call( -# [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"] -# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"]) -# -# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"]) -from memgpt.connectors.storage import StorageConnector, TableType +from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.embeddings import embedding_model -from memgpt.data_types import Message, Passage -from memgpt.config import MemGPTConfig, AgentConfig -from memgpt.utils import get_local_time -from memgpt.connectors.storage import StorageConnector, TableType -from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN +from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig +from memgpt.config import MemGPTConfig +from memgpt.agent_store.storage import StorageConnector, TableType +from memgpt.metadata import MetadataStore +from memgpt.data_types import User -import argparse from datetime import datetime, timedelta + # Note: the database will filter out rows that do not correspond to agent1 and test_user by default. 
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"] start_date = datetime(2009, 10, 5, 18, 00) dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)] roles = ["user", "agent", "agent"] -agent_ids = ["agent1", "agent2", "agent1"] +agent_1_id = uuid.uuid4() +agent_2_id = uuid.uuid4() +agent_ids = [agent_1_id, agent_2_id, agent_1_id] ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()] -user_id = "test_user" +user_id = uuid.uuid4() # Data generation functions: Passages @@ -52,17 +48,44 @@ def generate_messages(embed_model): if embed_model: embedding = embed_model.get_text_embedding(text) messages.append( - Message(user_id=user_id, text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt4", embedding=embedding) + Message(user_id=user_id, text=text, agent_id=agent_id, role=role, created_at=date, id=id, model="gpt-4", embedding=embedding) ) print(messages[-1].text) return messages +@pytest.fixture(autouse=True) +def clear_dynamically_created_models(): + """Wipe globals for SQLAlchemy""" + yield + for key in list(globals().keys()): + if key.endswith("Model"): + del globals()[key] + + +@pytest.fixture(autouse=True) +def recreate_declarative_base(): + """Recreate the declarative base before each test""" + global Base + Base = declarative_base() + yield + Base.metadata.clear() + + @pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqlite"]) +# @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"]) +# @pytest.mark.parametrize("storage_connector", ["postgres"]) @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY]) -def test_storage(storage_connector, table_type): +def test_storage(storage_connector, table_type, clear_dynamically_created_models, recreate_declarative_base): # setup memgpt config # TODO: set env for different config path + + # hacky way to cleanup globals that scruw up tests + # for table_name in ['Message']: + # 
if 'Message' in globals(): + # print("Removing messages", globals()['Message']) + # del globals()['Message'] + config = MemGPTConfig() if storage_connector == "postgres": if not os.getenv("PGVECTOR_TEST_DB_URL"): @@ -91,39 +114,43 @@ def test_storage(storage_connector, table_type): if table_type == TableType.ARCHIVAL_MEMORY: print("Skipping test, sqlite only supported for recall memory") return - config.recall_storage_type = "local" + config.recall_storage_type = "sqlite" # get embedding model embed_model = None if os.getenv("OPENAI_API_KEY"): - config.embedding_endpoint_type = "openai" - config.embedding_endpoint = "https://api.openai.com/v1" - config.embedding_dim = 1536 - config.openai_key = os.getenv("OPENAI_API_KEY") + embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + openai_key=os.getenv("OPENAI_API_KEY"), + ) else: - config.embedding_endpoint_type = "local" - config.embedding_endpoint = None - config.embedding_dim = 384 - config.save() - embed_model = embedding_model() + embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384) + embed_model = embedding_model(embedding_config) - # create agent - agent_config = AgentConfig( - name="agent1", - persona=DEFAULT_PERSONA, - human=DEFAULT_HUMAN, - model=DEFAULT_MEMGPT_MODEL, + # create user + ms = MetadataStore(config) + ms.delete_user(user_id) + user = User(id=user_id, default_embedding_config=embedding_config) + agent = AgentState( + user_id=user_id, + name="agent_1", + id=agent_1_id, + preset=user.default_preset, + persona=user.default_persona, + human=user.default_human, + llm_config=user.default_llm_config, + embedding_config=user.default_embedding_config, ) + ms.create_user(user) + ms.create_agent(agent) # create storage connector - conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config) + 
conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id) # conn.client.delete_collection(conn.collection.name) # clear out data conn.delete_table() - conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config) - - # override filters - conn.user_id = user_id - conn.filters = {"user_id": user_id, "agent_id": "agent1"} + conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id) # generate data if table_type == TableType.ARCHIVAL_MEMORY: @@ -169,7 +196,7 @@ def test_storage(storage_connector, table_type): # test: size assert conn.size() == 2, f"Expected 2 records, got {conn.size()}" - assert conn.size(filters={"agent_id": "agent1"}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', 'agent1'})}" + assert conn.size(filters={"agent_id": agent.id}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', agent.id})}" if table_type == TableType.RECALL_MEMORY: assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}" @@ -202,3 +229,6 @@ def test_storage(storage_connector, table_type): # test: delete conn.delete({"id": ids[0]}) assert conn.size() == 1, f"Expected 2 records, got {conn.size()}" + + # cleanup + ms.delete_user(user_id) diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py index b84b53cc..52704263 100644 --- a/tests/test_websocket_interface.py +++ b/tests/test_websocket_interface.py @@ -18,7 +18,7 @@ from memgpt.persistence_manager import LocalStateManager # persistence_manager = InMemoryStateManager() # # Create an agent and hook it up to the WebSocket interface -# memgpt_agent = presets.use_preset( +# memgpt_agent = presets.create_agent_from_preset( # presets.DEFAULT_PRESET, # None, # no agent config to provide # "gpt-4-1106-preview", @@ -42,6 +42,7 @@ async def test_dummy(): assert True 
+@pytest.mark.skip(reason="websockets is temporarily unsupported in 0.2.12") @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key") @pytest.mark.asyncio async def test_websockets(): @@ -76,7 +77,7 @@ async def test_websockets(): ) persistence_manager = LocalStateManager(agent_config=agent_config) - memgpt_agent = presets.use_preset( + memgpt_agent = presets.create_agent_from_preset( agent_config.preset, agent_config, agent_config.model, diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py index 4a530e4d..58423e79 100644 --- a/tests/test_websocket_server.py +++ b/tests/test_websocket_server.py @@ -14,6 +14,7 @@ async def test_dummy(): assert True +@pytest.mark.skip(reason="websockets is temporarily unsupported in 0.2.12") @pytest.mark.asyncio async def test_websocket_server(): # host = "127.0.0.1" diff --git a/tests/utils.py b/tests/utils.py index c08c0088..37f8bb87 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,4 @@ +import datetime import os from memgpt.config import MemGPTConfig @@ -16,6 +17,22 @@ def wipe_config(): os.remove(config_path) +def wipe_memgpt_home(): + """Wipes ~/.memgpt (moves to a backup), and initializes a new ~/.memgpt dir""" + + # Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS) + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + # Construct the new backup directory name with the timestamp + backup_dir = f"~/.memgpt_test_backup_{timestamp}" + + # Use os.system to execute the 'mv' command + os.system(f"mv ~/.memgpt {backup_dir}") + + # Setup the initial directory + MemGPTConfig.create_config_dir() + + def configure_memgpt_localllm(): import pexpect