From 3b9915171168b35845c51b478dd0af3fe269f7db Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 22 Feb 2024 11:16:01 -0800 Subject: [PATCH] feat: Create `RESTClient` and `Admin` client for interacting with server from python (#1033) --- .github/workflows/tests.yml | 17 +- README.md | 91 +------ docs/python_client.md | 74 ++---- examples/memgpt_client.py | 47 ++++ memgpt/__init__.py | 3 +- memgpt/benchmark/benchmark.py | 4 +- memgpt/cli/cli_load.py | 2 + memgpt/client/admin.py | 26 ++ memgpt/client/client.py | 230 +++++++++++++----- memgpt/config.py | 2 +- memgpt/metadata.py | 2 +- memgpt/models/pydantic_models.py | 33 +++ memgpt/server/rest_api/admin/users.py | 8 +- memgpt/server/rest_api/agents/config.py | 12 +- memgpt/server/rest_api/agents/index.py | 27 +- .../rest_api/openai_assistants/assistants.py | 3 - memgpt/server/server.py | 35 ++- tests/test_agent_function_update.py | 18 +- tests/test_base_functions.py | 17 +- tests/test_cli.py | 8 +- tests/test_client.py | 163 ++++++------- tests/test_different_embedding_size.py | 29 +-- tests/test_load_archival.py | 7 +- tests/test_migrate.py | 1 - tests/test_openai_assistant_api.py | 92 +++---- tests/test_persistence.py | 52 ++++ tests/test_server.py | 6 +- tests/test_summarize.py | 21 +- tests/utils.py | 12 + 29 files changed, 623 insertions(+), 419 deletions(-) create mode 100644 examples/memgpt_client.py create mode 100644 memgpt/client/admin.py create mode 100644 memgpt/models/pydantic_models.py create mode 100644 tests/test_persistence.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0dbf495d..4ddc0d4a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,10 +28,19 @@ jobs: poetry-version: "1.7.1" install-args: "--all-extras" + - name: Run server tests + env: + PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + MEMGPT_SERVER_PASS: test_server_token + run: | + poetry run pytest -s -vv tests/test_server.py + - name: Run tests with pytest env: PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + MEMGPT_SERVER_PASS: test_server_token run: | poetry run pytest -s -vv -k "not test_storage and not test_server and not test_openai_client" tests @@ -39,12 +48,6 @@ jobs: env: PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + MEMGPT_SERVER_PASS: test_server_token run: | poetry run pytest -s -vv tests/test_storage.py - - - name: Run server tests - env: - PGVECTOR_TEST_DB_URL: postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: | - poetry run pytest -s -vv tests/test_server.py diff --git a/README.md b/README.md index e0728e27..85f0cfd1 100644 --- a/README.md +++ b/README.md @@ -129,91 +129,24 @@ poetry install ## Python integration (for developers) -The fastest way to integrate MemGPT with your own Python projects is through the `MemGPT` client class: +The fastest way to integrate MemGPT with your own Python projects is through the MemGPT client: ```python -from memgpt import MemGPT +from memgpt import create_client -# Create a MemGPT client object (sets up the persistent state) -client = MemGPT( - quickstart="openai", - config={ - "openai_api_key": "YOUR_API_KEY" - } +# Connect to the server as a user +client = create_client() + +# Create an agent +agent_info = client.create_agent( + name="my_agent", + persona="You are a friendly agent.", + human="Bob is a friendly human." ) -# You can set many more parameters, this is just a basic example -agent_state = client.create_agent( - agent_config={ - "persona": "sam_pov", - "human": "cs_phd", - } -) - -# Now that we have an agent_name identifier, we can send it a message! -# The response will have data from the MemGPT agent -my_message = "Hi MemGPT! How's it going?" -response = client.user_message(agent_id=agent_state.id, message=my_message) +# Send a message to the agent +messages = client.user_message(agent_id=agent_info.id, message="Hello, agent!") ``` -
- - More in-depth example of using MemGPT Client - - -```python -from memgpt.config import AgentConfig -from memgpt import MemGPT -from memgpt.cli.cli import QuickstartChoice - - -client = MemGPT( - # When auto_save is 'True' then the agent(s) will be saved after every - # user message. This may have performance implications, so you - # can otherwise choose when to save explicitly using client.save(). - auto_save=True, - - # Quickstart will automatically configure MemGPT (without having to run `memgpt configure` - # If you choose 'openai' then you must set the api key (env or in config) - quickstart=QuickstartChoice.memgpt_hosted, - - # Allows you to override default config generated by quickstart or `memgpt configure` - config={} -) - -# Create an AgentConfig with default persona and human txt -# In this case, assume we wrote a custom persona file "my_persona.txt", located at ~/.memgpt/personas/my_persona.txt -# Same for a custom user file "my_user.txt", located at ~/.memgpt/humans/my_user.txt -agent_config = AgentConfig( - name="CustomAgent", - persona="my_persona", - human="my_user", -) - -# Create the agent according to AgentConfig we set up. If an agent with -# the same name already exists it will simply return, unless you set -# throw_if_exists to 'True' -agent_id = client.create_agent(agent_config=agent_config) - -# Create a helper that sends a message and prints the assistant response only -def send_message(message: str): - """ - sends a message and prints the assistant output only. - :param message: the message to send - """ - response = client.user_message(agent_id=agent_id, message=message) - for r in response: - # Can also handle other types "function_call", "function_return", "function_message" - if "assistant_message" in r: - print("ASSISTANT:", r["assistant_message"]) - elif "thoughts" in r: - print("THOUGHTS:", r["internal_monologue"]) - -# Send a message and see the response -send_message("Please introduce yourself and tell me about your abilities!") -``` - -
- ## What open LLMs work well with MemGPT? When using MemGPT with open LLMs (such as those downloaded from HuggingFace), the performance of MemGPT will be highly dependent on the LLM's function calling ability. diff --git a/docs/python_client.md b/docs/python_client.md index 2e535366..bf79b3d2 100644 --- a/docs/python_client.md +++ b/docs/python_client.md @@ -4,72 +4,42 @@ excerpt: Developing using the MemGPT Python client category: 6580dab16cade8003f996d17 --- -The fastest way to integrate MemGPT with your own Python projects is through the `MemGPT` [client class](https://github.com/cpacker/MemGPT/blob/main/memgpt/client/client.py): +The fastest way to integrate MemGPT with your own Python projects is through the [client class](https://github.com/cpacker/MemGPT/blob/main/memgpt/client/client.py): ```python -from memgpt import MemGPT +from memgpt import create_client -# Create a MemGPT client object (sets up the persistent state) -client = MemGPT( - quickstart="openai", - config={ - "openai_api_key": "YOUR_API_KEY" - } +# Connect to the server as a user +client = create_client() + +# Create an agent +agent_info = client.create_agent( + name="my_agent", + persona="You are a friendly agent.", + human="Bob is a friendly human." ) -# You can set many more parameters, this is just a basic example -agent_state = client.create_agent( - agent_config={ - "persona": "sam_pov", - "human": "cs_phd", - } -) - -# Now that we have an agent_name identifier, we can send it a message! -# The response will have data from the MemGPT agent -my_message = "Hi MemGPT! How's it going?" -response = client.user_message(agent_id=agent_state.id, message=my_message) +# Send a message to the agent +messages = client.user_message(agent_id=agent_info.id, message="Hello, agent!") ``` ## More in-depth example of using the MemGPT Python client ```python -from memgpt.config import AgentConfig -from memgpt import MemGPT -from memgpt import constants -from memgpt.cli.cli import QuickstartChoice +from memgpt import create_client +# Connect to the server as a user +client = create_client() -client = MemGPT( - # When auto_save is 'True' then the agent(s) will be saved after every - # user message. This may have performance implications, so you - # can otherwise choose when to save explicitly using client.save(). - auto_save=True, - - # Quickstart will automatically configure MemGPT (without having to run `memgpt configure` - # If you choose 'openai' then you must set the api key (env or in config) - quickstart=QuickstartChoice.memgpt_hosted, - - # Allows you to override default config generated by quickstart or `memgpt configure` - config={} +# Create an agent +agent_info = client.create_agent( + name="my_agent", + persona="You are a friendly agent.", + human="Bob is a friendly human." ) -# Create an AgentConfig with default persona and human txt -# In this case, assume we wrote a custom persona file "my_persona.txt", located at ~/.memgpt/personas/my_persona.txt -# Same for a custom user file "my_user.txt", located at ~/.memgpt/humans/my_user.txt -agent_config = AgentConfig( - name="CustomAgent", - persona="my_persona", - human="my_user", - preset="memgpt_chat", - model="gpt-4", -) - -# Create the agent according to AgentConfig we set up. If an agent with -# the same name already exists it will simply return, unless you set -# throw_if_exists to 'True' -agent_state = client.create_agent(agent_config=agent_config) - +# Send a message to the agent +messages = client.user_message(agent_id=agent_info.id, message="Hello, agent!") # Create a helper that sends a message and prints the assistant response only def send_message(message: str): """ diff --git a/examples/memgpt_client.py b/examples/memgpt_client.py new file mode 100644 index 00000000..75d35f59 --- /dev/null +++ b/examples/memgpt_client.py @@ -0,0 +1,47 @@ +from memgpt import create_client, Admin +from memgpt.constants import DEFAULT_PRESET, DEFAULT_HUMAN, DEFAULT_PERSONA + + +""" +Make sure you run the MemGPT server before running this example. +``` +export MEMGPT_SERVER_PASS=your_token +memgpt server +``` +""" + + +def main(): + # Create an admin client + admin = Admin(base_url="http://localhost:8283", token="your_token") + + # Create a user + token + user_id, token = admin.create_user() + print(f"Created user: {user_id} with token: {token}") + + # Connect to the server as a user + client = create_client(base_url="http://localhost:8283", token=token) + + # Create an agent + agent_info = client.create_agent(name="my_agent", preset=DEFAULT_PRESET, persona=DEFAULT_PERSONA, human=DEFAULT_HUMAN) + print(f"Created agent: {agent_info.name} with ID {str(agent_info.id)}") + + # Send a message to the agent + messages = client.user_message(agent_id=agent_info.id, message="Hello, agent!") + print(f"Recieved response: {messages}") + + # TODO: get agent memory + + # TODO: Update agent persona + + # Delete agent + client.delete_agent(agent_id=agent_info.id) + print(f"Deleted agent: {agent_info.name} with ID {str(agent_info.id)}") + + # Delete user + admin.delete_user(user_id=user_id) + print(f"Deleted user: {user_id} with token: {token}") + + +if __name__ == "__main__": + main() diff --git a/memgpt/__init__.py b/memgpt/__init__.py index f70d69aa..b1243926 100644 --- a/memgpt/__init__.py +++ b/memgpt/__init__.py @@ -1,3 +1,4 @@ __version__ = "0.3.3" -from memgpt.client.client import Client as MemGPT +from memgpt.client.client import create_client +from memgpt.client.admin import Admin diff --git a/memgpt/benchmark/benchmark.py b/memgpt/benchmark/benchmark.py index 0a82b1d4..ed7d0714 100644 --- a/memgpt/benchmark/benchmark.py +++ b/memgpt/benchmark/benchmark.py @@ -5,7 +5,7 @@ import typer import time from typing import Annotated, Optional -from memgpt import MemGPT +from memgpt import create_client from memgpt.config import MemGPTConfig # from memgpt.agent import Agent @@ -42,7 +42,7 @@ def bench( print_messages: Annotated[bool, typer.Option("--messages", help="Print functions calls and messages from the agent.")] = False, n_tries: Annotated[int, typer.Option("--n-tries", help="Number of benchmark tries to perform for each function.")] = TRIES, ): - client = MemGPT() + client = create_client() print(f"\nDepending on your hardware, this may take up to 30 minutes. This will also create {n_tries * len(PROMPTS)} new agents.\n") config = MemGPTConfig.load() print(f"version = {config.memgpt_version}") diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 2c44fe11..a8e09ffe 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -99,6 +99,8 @@ def load_directory( ms = MetadataStore(config) source = Source(name=name, user_id=user_id) ms.create_source(source) + print("created source", name, str(user_id)) + print("listing source", user_id, ms.list_sources(user_id=user_id)) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) # TODO: also get document store diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py new file mode 100644 index 00000000..a47c956b --- /dev/null +++ b/memgpt/client/admin.py @@ -0,0 +1,26 @@ +from typing import Optional +import requests + + +class Admin: + """ + Admin client allows admin-level operations on the MemGPT server. + - Creating users + - Generating user keys + """ + + def __init__(self, base_url: str, token: str): + self.base_url = base_url + self.token = token + self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"} + + def create_user(self, user_id: Optional[str] = None): + payload = {"user_id": str(user_id) if user_id else None} + response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload) + response_json = response.json() + print(response_json) + return response_json["user_id"], response_json["api_key"] + + def delete_user(self, user_id: str): + response = requests.delete(f"{self.base_url}/admin/users/{user_id}", headers=self.headers) + return response.json() diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 0e315e7e..0a923aa8 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -1,8 +1,10 @@ import os +import datetime +import requests import uuid from typing import Dict, List, Union, Optional, Tuple -from memgpt.data_types import AgentState, User, Preset +from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig @@ -11,13 +13,157 @@ from memgpt.server.server import SyncServer from memgpt.metadata import MetadataStore -class Client(object): +def create_client(base_url: Optional[str] = None, token: Optional[str] = None): + if base_url is None: + return LocalClient() + else: + return RESTClient(base_url, token) + + +class AbstractClient(object): def __init__( self, - user_id: str = None, auto_save: bool = False, - quickstart: Union[QuickstartChoice, str, None] = None, - config: Union[Dict, MemGPTConfig] = None, # not the same thing as AgentConfig + debug: bool = False, + ): + self.auto_save = auto_save + self.debug = debug + + def list_agents(self): + """List all agents associated with a given user.""" + raise NotImplementedError + + def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: + """Check if an agent with the specified ID or name exists.""" + raise NotImplementedError + + def create_agent( + self, + name: Optional[str] = None, + preset: Optional[str] = None, + persona: Optional[str] = None, + human: Optional[str] = None, + embedding_config: Optional[EmbeddingConfig] = None, + llm_config: Optional[LLMConfig] = None, + ) -> AgentState: + """Create a new agent with the specified configuration.""" + raise NotImplementedError + + def create_preset(self, preset: Preset): + raise NotImplementedError + + def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState: + raise NotImplementedError + + def get_agent_memory(self, agent_id: str) -> Dict: + raise NotImplementedError + + def update_agent_core_memory(self, agent_id: str, human: Optional[str] = None, persona: Optional[str] = None) -> Dict: + raise NotImplementedError + + def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]: + raise NotImplementedError + + def run_command(self, agent_id: str, command: str) -> Union[str, None]: + raise NotImplementedError + + def save(self): + raise NotImplementedError + + +class RESTClient(AbstractClient): + def __init__( + self, + base_url: str, + token: str, + debug: bool = False, + ): + super().__init__(debug=debug) + self.base_url = base_url + self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"} + + def list_agents(self): + response = requests.get(f"{self.base_url}/agents", headers=self.headers) + print(response.text) + + def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: + response = requests.get(f"{self.base_url}/agents/config?agent_id={str(agent_id)}", headers=self.headers) + print(response.text) + + def create_agent( + self, + name: Optional[str] = None, + preset: Optional[str] = None, + persona: Optional[str] = None, + human: Optional[str] = None, + embedding_config: Optional[EmbeddingConfig] = None, + llm_config: Optional[LLMConfig] = None, + ) -> AgentState: + if embedding_config or llm_config: + raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API") + payload = { + "config": { + "name": name, + "preset": preset, + "persona": persona, + "human": human, + } + } + response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers) + response_json = response.json() + llm_config = LLMConfig(**response_json["agent_state"]["llm_config"]) + embedding_config = EmbeddingConfig(**response_json["agent_state"]["embedding_config"]) + agent_state = AgentState( + id=uuid.UUID(response_json["agent_state"]["id"]), + name=response_json["agent_state"]["name"], + user_id=uuid.UUID(response_json["agent_state"]["user_id"]), + preset=response_json["agent_state"]["preset"], + persona=response_json["agent_state"]["persona"], + human=response_json["agent_state"]["human"], + llm_config=llm_config, + embedding_config=embedding_config, + state=response_json["agent_state"]["state"], + # load datetime from timestampe + created_at=datetime.datetime.fromtimestamp(response_json["agent_state"]["created_at"]), + ) + return agent_state + + def delete_agent(self, agent_id: str): + response = requests.delete(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers) + return agent_id + + def create_preset(self, preset: Preset): + raise NotImplementedError + + def get_agent_config(self, agent_id: str) -> AgentState: + raise NotImplementedError + + def get_agent_memory(self, agent_id: str) -> Dict: + raise NotImplementedError + + def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> Dict: + raise NotImplementedError + + def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]: + # TODO: support role? what is return_token_count? + payload = {"agent_id": str(agent_id), "message": message} + response = requests.post(f"{self.base_url}/api/agents/message", json=payload, headers=self.headers) + response_json = response.json() + print(response_json) + return response_json + + def run_command(self, agent_id: str, command: str) -> Union[str, None]: + raise NotImplementedError + + def save(self): + raise NotImplementedError + + +class LocalClient(AbstractClient): + def __init__( + self, + auto_save: bool = False, + user_id: Optional[str] = None, debug: bool = False, ): """ @@ -28,47 +174,13 @@ class Client(object): :param debug: indicates whether to display debug messages. """ self.auto_save = auto_save - # make sure everything is set up properly - # TODO: remove this eventually? for multi-user, we can't have a shared config directory - MemGPTConfig.create_config_dir() - # If this is the first ever start, do basic initialization - if not MemGPTConfig.exists() and config is None and quickstart is None: - # Default to openai - print("Detecting uninitialized MemGPT, defaulting to quickstart == openai") - quickstart = "openai" - - if quickstart: - # api key passed in config has priority over env var - if isinstance(config, dict) and "openai_api_key" in config: - openai_key = config["openai_api_key"] - else: - openai_key = os.environ.get("OPENAI_API_KEY", None) - - # throw an error if we can't resolve the key - if openai_key: - os.environ["OPENAI_API_KEY"] = openai_key - elif quickstart == QuickstartChoice.openai or quickstart == "openai": - raise ValueError("Please set OPENAI_API_KEY or pass 'openai_api_key' in config dict") - - if isinstance(quickstart, str): - quickstart = str_to_quickstart_choice(quickstart) - quickstart_func(backend=quickstart, debug=debug) - - if config is not None: - set_config_with_dict(config) - - # determine user_id + # determine user_id (pulled from local config) config = MemGPTConfig.load() - if user_id is None: - # the default user_id - self.user_id = uuid.UUID(config.anon_clientid) - elif isinstance(user_id, str): + if user_id: self.user_id = uuid.UUID(user_id) - elif isinstance(user_id, uuid.UUID): - self.user_id = user_id else: - raise TypeError(user_id) + self.user_id = uuid.UUID(config.anon_clientid) # create user if does not exist ms = MetadataStore(config) @@ -104,25 +216,33 @@ class Client(object): def create_agent( self, - agent_config: dict, + name: Optional[str] = None, + preset: Optional[str] = None, + persona: Optional[str] = None, + human: Optional[str] = None, + embedding_config: Optional[EmbeddingConfig] = None, + llm_config: Optional[LLMConfig] = None, ) -> AgentState: - if isinstance(agent_config, dict): - agent_name = agent_config.get("name") - else: - raise TypeError(f"agent_config must be of type dict") - - if "name" in agent_config and self.agent_exists(agent_name=agent_config["name"]): - raise ValueError(f"Agent with name {agent_config['name']} already exists (user_id={self.user_id})") + if name and self.agent_exists(agent_name=name): + raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})") self.interface.clear() - agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config) + agent_state = self.server.create_agent( + user_id=self.user_id, + name=name, + preset=preset, + persona=persona, + human=human, + embedding_config=embedding_config, + llm_config=llm_config, + ) return agent_state def create_preset(self, preset: Preset): preset = self.server.create_preset(preset=preset) return preset - def get_agent_config(self, agent_id: str) -> Dict: + def get_agent_config(self, agent_id: str) -> AgentState: self.interface.clear() return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id) @@ -134,13 +254,11 @@ class Client(object): self.interface.clear() return self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, new_memory_contents=new_memory_contents) - def user_message(self, agent_id: str, message: str, return_token_count: bool = False) -> Union[List[Dict], Tuple[List[Dict], int]]: + def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]: self.interface.clear() - tokens_accumulated = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message) + self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message) if self.auto_save: self.save() - if return_token_count: - return self.interface.to_list(), tokens_accumulated else: return self.interface.to_list() diff --git a/memgpt/config.py b/memgpt/config.py index c259fbb4..9886d90f 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -38,7 +38,7 @@ def set_field(config, section, field, value): @dataclass class MemGPTConfig: - config_path: str = os.path.join(MEMGPT_DIR, "config") + config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config") anon_clientid: str = None # preset diff --git a/memgpt/metadata.py b/memgpt/metadata.py index cd609d88..26146018 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -383,7 +383,7 @@ class MetadataStore: with self.session_maker() as session: if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0: if not exists_ok: - raise ValueError(f"Source with name {source.name} already exists") + raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}") else: session.update(SourceModel(**vars(source))) else: diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py new file mode 100644 index 00000000..f177cc95 --- /dev/null +++ b/memgpt/models/pydantic_models.py @@ -0,0 +1,33 @@ +from typing import List, Union, Optional, Dict, Literal +from enum import Enum +from pydantic import BaseModel, Field, Json +import uuid + + +class LLMConfigModel(BaseModel): + model: Optional[str] = "gpt-4" + model_endpoint_type: Optional[str] = "openai" + model_endpoint: Optional[str] = "https://api.openai.com/v1" + model_wrapper: Optional[str] = None + context_window: Optional[int] = None + + +class EmbeddingConfigModel(BaseModel): + embedding_endpoint_type: Optional[str] = "openai" + embedding_endpoint: Optional[str] = "https://api.openai.com/v1" + embedding_model: Optional[str] = "text-embedding-ada-002" + embedding_dim: Optional[int] = 1536 + embedding_chunk_size: Optional[int] = 300 + + +class AgentStateModel(BaseModel): + id: uuid.UUID = Field(..., description="The unique identifier of the agent.") + name: str = Field(..., description="The name of the agent.") + user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.") + preset: str = Field(..., description="The preset used by the agent.") + persona: str = Field(..., description="The persona used by the agent.") + human: str = Field(..., description="The human used by the agent.") + llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.") + embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.") + state: Optional[Dict] = Field(None, description="The state of the agent.") + created_at: int = Field(..., description="The unix timestamp of when the agent was created.") diff --git a/memgpt/server/rest_api/admin/users.py b/memgpt/server/rest_api/admin/users.py index f2981a02..01a6383a 100644 --- a/memgpt/server/rest_api/admin/users.py +++ b/memgpt/server/rest_api/admin/users.py @@ -73,14 +73,18 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): """ Create a new user in the database """ + print("REQUEST ID", request.user_id, request.user_id is None, type(request.user_id)) new_user = User( - id=uuid.UUID(request.user_id) if request.user_id is not None else None, + id=None if not request.user_id else uuid.UUID(request.user_id), # TODO can add more fields (name? metadata?) ) try: server.ms.create_user(new_user) + # initialize default presets automatically for user + server.initialize_default_presets(new_user.id) + # make sure we can retrieve the user from the DB too new_user_ret = server.ms.get_user(new_user.id) if new_user_ret is None: @@ -93,7 +97,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") - return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token) + return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token) @router.delete("/users", tags=["admin"], response_model=DeleteUserResponse) def delete_user( diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index 0a2b4f23..1134f50f 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer from memgpt.server.rest_api.auth_token import get_current_user +from memgpt.models.pydantic_models import AgentStateModel router = APIRouter() @@ -23,7 +24,8 @@ class AgentRenameRequest(BaseModel): class AgentConfigResponse(BaseModel): - config: dict = Field(..., description="The agent configuration object.") + # config: dict = Field(..., description="The agent configuration object.") + agent_state: AgentStateModel = Field(..., description="The state of the agent.") def validate_agent_name(name: str) -> str: @@ -61,8 +63,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface): agent_id = uuid.UUID(request.agent_id) if request.agent_id else None interface.clear() - config = server.get_agent_config(user_id=user_id, agent_id=agent_id) - return AgentConfigResponse(config=config) + agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id) + return AgentConfigResponse(agent_state=agent_state) @router.patch("/agents/rename", tags=["agents"], response_model=AgentConfigResponse) def update_agent_name( @@ -80,12 +82,12 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface): interface.clear() try: - config = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name) + agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") - return AgentConfigResponse(config=config) + return AgentConfigResponse(agent_state=agent_state) @router.delete("/agents", tags=["agents"]) def delete_agent( diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index fa9e7c91..be82ef36 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, Field from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer from memgpt.server.rest_api.auth_token import get_current_user +from memgpt.data_types import AgentState +from memgpt.models.pydantic_models import LLMConfigModel, EmbeddingConfigModel, AgentStateModel router = APIRouter() @@ -22,7 +24,7 @@ class CreateAgentRequest(BaseModel): class CreateAgentResponse(BaseModel): - agent_id: uuid.UUID = Field(..., description="Unique identifier of the newly created agent.") + agent_state: AgentStateModel = Field(..., description="The state of the newly created agent.") def setup_agents_index_router(server: SyncServer, interface: QueuingInterface): @@ -52,9 +54,28 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface): interface.clear() try: - agent_state = server.create_agent(user_id=user_id, agent_config=request.config) - return CreateAgentResponse(agent_id=agent_state.id) + agent_state = server.create_agent(user_id=user_id, **request.config) + llm_config = LLMConfigModel(**vars(agent_state.llm_config)) + embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) + return CreateAgentResponse( + agent_state=AgentStateModel( + id=agent_state.id, + name=agent_state.name, + user_id=agent_state.user_id, + preset=agent_state.preset, + persona=agent_state.persona, + human=agent_state.human, + llm_config=llm_config, + embedding_config=embedding_config, + state=agent_state.state, + created_at=int(agent_state.created_at.timestamp()), + ) + ) + # return CreateAgentResponse( + # agent_state=AgentStateModel( + # ) except Exception as e: + print(str(e)) raise HTTPException(status_code=500, detail=str(e)) return router diff --git a/memgpt/server/rest_api/openai_assistants/assistants.py b/memgpt/server/rest_api/openai_assistants/assistants.py index 6336b610..fd2a1c5b 100644 --- a/memgpt/server/rest_api/openai_assistants/assistants.py +++ b/memgpt/server/rest_api/openai_assistants/assistants.py @@ -261,9 +261,6 @@ def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterfac # create a memgpt agent agent_state = server.create_agent( user_id=user_id, - agent_config={ - "user_id": user_id, - }, ) # TODO: insert messages into recall memory return OpenAIThread( diff --git a/memgpt/server/server.py b/memgpt/server/server.py index cd07884a..98e10382 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -206,6 +206,8 @@ class SyncServer(LockingServer): # Initialize the connection to the DB self.config = MemGPTConfig.load() + assert self.config.persona is not None, "Persona must be set in the config" + assert self.config.human is not None, "Human must be set in the config" # TODO figure out how to handle credentials for the server self.credentials = MemGPTCredentials.load() @@ -570,7 +572,12 @@ class SyncServer(LockingServer): def create_agent( self, user_id: uuid.UUID, - agent_config: Union[dict, AgentState], + name: Optional[str] = None, + preset: Optional[str] = None, + persona: Optional[str] = None, + human: Optional[str] = None, + llm_config: Optional[LLMConfig] = None, + embedding_config: Optional[EmbeddingConfig] = None, interface: Union[AgentInterface, None] = None, # persistence_manager: Union[PersistenceManager, None] = None, ) -> AgentState: @@ -578,10 +585,6 @@ class SyncServer(LockingServer): if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") - # Initialize the agent based on the provided configuration - if not isinstance(agent_config, dict): - raise ValueError(f"agent_config must be provided as a dictionary") - if interface is None: # interface = self.default_interface_cls() interface = self.default_interface @@ -596,13 +599,13 @@ class SyncServer(LockingServer): agent_state = AgentState( user_id=user.id, - name=agent_config["name"] if "name" in agent_config else utils.create_random_username(), - preset=agent_config["preset"] if "preset" in agent_config else self.config.preset, + name=name if name else utils.create_random_username(), + preset=preset if preset else self.config.preset, # TODO we need to allow passing raw persona/human text via the server request - persona=agent_config["persona"] if "persona" in agent_config else self.config.persona, - human=agent_config["human"] if "human" in agent_config else self.config.human, - llm_config=agent_config["llm_config"] if "llm_config" in agent_config else self.server_llm_config, - embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else self.server_embedding_config, + persona=persona if persona else self.config.persona, + human=human if human else self.config.human, + llm_config=llm_config if llm_config else self.server_llm_config, + embedding_config=embedding_config if embedding_config else self.server_embedding_config, ) # NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation # TODO: fix this db dependency and remove @@ -851,7 +854,7 @@ class SyncServer(LockingServer): # TODO: mark what is in-context versus not return cursor, json_records - def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict: + def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> AgentState: """Return the config of an agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -860,9 +863,7 @@ class SyncServer(LockingServer): # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) - agent_config = vars(memgpt_agent.agent_state) - - return agent_config + return memgpt_agent.agent_state def get_server_config(self, include_defaults: bool = False) -> dict: """Return the base config""" @@ -966,9 +967,7 @@ class SyncServer(LockingServer): logger.exception(f"Failed to update agent name with:\n{str(e)}") raise ValueError(f"Failed to update agent name in database") - # return the new config (only the name should have been updated) - agent_config = self._agent_state_to_config(agent_state=memgpt_agent.agent_state) - return agent_config + return memgpt_agent.agent_state def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID): """Delete an agent in the database""" diff --git a/tests/test_agent_function_update.py b/tests/test_agent_function_update.py index 4570a5e7..45fd3bd2 100644 --- a/tests/test_agent_function_update.py +++ b/tests/test_agent_function_update.py @@ -4,15 +4,15 @@ import os import inspect import uuid -from memgpt import MemGPT from memgpt.config import MemGPTConfig +from memgpt import create_client from memgpt import constants import memgpt.functions.function_sets.base as base_functions from memgpt.functions.functions import USER_FUNCTIONS_DIR from memgpt.utils import assistant_function_to_tool from memgpt.models import chat_completion_response -from tests.utils import wipe_config +from tests.utils import wipe_config, create_config import pytest @@ -32,9 +32,12 @@ def agent(): wipe_config() global client if os.getenv("OPENAI_API_KEY"): - client = MemGPT(quickstart="openai") + create_config("openai") else: - client = MemGPT(quickstart="memgpt_hosted") + create_config("memgpt_hosted") + + # create memgpt client + client = create_client() config = MemGPTConfig.load() @@ -44,11 +47,8 @@ def agent(): client.server.create_user({"id": user_id}) agent_state = client.create_agent( - agent_config={ - # "name": test_agent_id, - "persona": constants.DEFAULT_PERSONA, - "human": constants.DEFAULT_HUMAN, - } + persona=constants.DEFAULT_PERSONA, + human=constants.DEFAULT_HUMAN, ) return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 1b8f1bea..c4280d51 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,11 +1,11 @@ import os import uuid -from memgpt import MemGPT +from memgpt import create_client from memgpt.config import MemGPTConfig from memgpt import constants import memgpt.functions.function_sets.base as base_functions -from .utils import wipe_config +from .utils import wipe_config, create_config # test_agent_id = "test_agent" @@ -18,16 +18,15 @@ def create_test_agent(): wipe_config() global client if os.getenv("OPENAI_API_KEY"): - client = MemGPT(quickstart="openai") + create_config("openai") else: - client = MemGPT(quickstart="memgpt_hosted") + create_config("memgpt_hosted") + + client = create_client() agent_state = client.create_agent( - agent_config={ - # "name": test_agent_id, - "persona": constants.DEFAULT_PERSONA, - "human": constants.DEFAULT_HUMAN, - } + persona=constants.DEFAULT_PERSONA, + human=constants.DEFAULT_HUMAN, ) global agent_obj diff --git a/tests/test_cli.py b/tests/test_cli.py index 37e03dce..e11c17ea 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,8 +6,8 @@ subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"]) import pexpect from .constants import TIMEOUT -from .utils import configure_memgpt -from memgpt import MemGPT +from .utils import create_config, wipe_config +from memgpt import create_client # def test_configure_memgpt(): @@ -17,9 +17,9 @@ from memgpt import MemGPT def test_save_load(): # configure_memgpt() # rely on configure running first^ if os.getenv("OPENAI_API_KEY"): - client = MemGPT(quickstart="openai") + create_config("openai") else: - client = MemGPT(quickstart="memgpt_hosted") + create_config("memgpt_hosted") child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui") diff --git a/tests/test_client.py b/tests/test_client.py index daaf7b74..39c3d33f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,9 @@ import uuid +import time import os +import threading -from memgpt import MemGPT +from memgpt import Admin, create_client from memgpt.config import MemGPTConfig from memgpt import constants from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset @@ -9,125 +11,114 @@ from memgpt.functions.functions import load_all_function_sets from memgpt.prompts import gpt_system from memgpt.constants import DEFAULT_PRESET +import pytest + from .utils import wipe_config import uuid test_agent_name = f"test_client_{str(uuid.uuid4())}" -test_preset_name = "test_preset" +# test_preset_name = "test_preset" +test_preset_name = DEFAULT_PRESET test_agent_state = None client = None test_agent_state_post_message = None test_user_id = uuid.uuid4() +test_base_url = "http://localhost:8283" -def test_create_preset(): - wipe_config() - global client - if os.getenv("OPENAI_API_KEY"): - client = MemGPT(quickstart="openai", user_id=test_user_id) +# admin credentials +test_server_token = "test_server_token" + + +def run_server(): + import uvicorn + from memgpt.server.rest_api.server import app + + uvicorn.run(app, host="localhost", port=8283, log_level="info") + + +@pytest.fixture(scope="session", autouse=True) +def start_uvicorn_server(): + """Starts Uvicorn server in a background thread.""" + + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + print("Starting server...") + time.sleep(5) + yield + + +@pytest.fixture(scope="module") +def user_token(): + # Setup: Create a user via the client before the tests + + admin = Admin(test_base_url, test_server_token) + user_id, token = admin.create_user(test_user_id) # Adjust as per your client's method + print(user_id, token) + + yield token + + # Teardown: Delete the user after the test (or after all tests if fixture scope is module/class) + admin.delete_user(test_user_id) # Adjust as per your client's method + + +# Fixture to create clients with different configurations +@pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module") +def client(request, user_token): + # use token or not + if request.param["base_url"]: + token = user_token else: - client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) + token = None - available_functions = load_all_function_sets(merge=True) - functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()] - preset = Preset( - name=test_preset_name, - user_id=test_user_id, - description="A preset for testing the MemGPT client", - system=gpt_system.get_system_text(DEFAULT_PRESET), - functions_schema=functions_schema, - ) - client.create_preset(preset) + client = create_client(**request.param, token=token) # This yields control back to the test function + yield client -def test_create_agent(): - wipe_config() - config = MemGPTConfig.load() +# TODO: add back once REST API supports +# def test_create_preset(client): +# +# available_functions = load_all_function_sets(merge=True) +# functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()] +# preset = Preset( +# name=test_preset_name, +# user_id=test_user_id, +# description="A preset for testing the MemGPT client", +# system=gpt_system.get_system_text(DEFAULT_PRESET), +# functions_schema=functions_schema, +# ) +# client.create_preset(preset) - # ensure user exists - if not client.server.get_user(user_id=test_user_id): - raise ValueError("User failed to be created") +def test_create_agent(client): global test_agent_state test_agent_state = client.create_agent( - agent_config={ - "user_id": test_user_id, - "name": test_agent_name, - "preset": test_preset_name, - } + name=test_agent_name, + preset=test_preset_name, ) print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") assert test_agent_state is not None -def test_user_message(): +def test_user_message(client): """Test that we can send a message through the client""" assert client is not None, "Run create_agent test first" print(f"\n\n[2] SENDING MESSAGE TO AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") response = client.user_message(agent_id=test_agent_state.id, message="Hello my name is Test, Client Test") assert response is not None and len(response) > 0 - global test_agent_state_post_message - client.server.active_agents[0]["agent"].update_state() - test_agent_state_post_message = client.server.active_agents[0]["agent"].agent_state - print( - f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}" - ) - - -def test_save_load(): - """Test that state is being persisted correctly after an /exit - - Create a new agent, and request a message - - Then trigger - """ - assert client is not None, "Run create_agent test first" - assert test_agent_state is not None, "Run create_agent test first" - assert test_agent_state_post_message is not None, "Run test_user_message test first" - - # Create a new client (not thread safe), and load the same agent - # The agent state inside should correspond to the initial state pre-message - if os.getenv("OPENAI_API_KEY"): - client2 = MemGPT(quickstart="openai", user_id=test_user_id) - else: - client2 = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) - print(f"\n\n[3] CREATING CLIENT2, LOADING AGENT {test_agent_state.id}!") - client2_agent_obj = client2.server._get_or_load_agent(user_id=test_user_id, agent_id=test_agent_state.id) - client2_agent_state = client2_agent_obj.update_state() - print(f"[3] LOADED AGENT! AGENT {client2_agent_state.id}\n\tmessages={client2_agent_state.state['messages']}") - - # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}" - def check_state_equivalence(state_1, state_2): - """Helper function that checks the equivalence of two AgentState objects""" - assert state_1.keys() == state_2.keys(), f"{state_1.keys()}\n{state_2.keys}" - for k, v1 in state_1.items(): - v2 = state_2[k] - if isinstance(v1, LLMConfig) or isinstance(v1, EmbeddingConfig): - assert vars(v1) == vars(v2), f"{vars(v1)}\n{vars(v2)}" - else: - assert v1 == v2, f"{v1}\n{v2}" - - check_state_equivalence(vars(test_agent_state), vars(client2_agent_state)) - - # Now, write out the save from the original client - # This should persist the test message into the agent state - client.save() - - if os.getenv("OPENAI_API_KEY"): - client3 = MemGPT(quickstart="openai", user_id=test_user_id) - else: - client3 = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) - client3_agent_obj = client3.server._get_or_load_agent(user_id=test_user_id, agent_id=test_agent_state.id) - client3_agent_state = client3_agent_obj.update_state() - - check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state)) + # global test_agent_state_post_message + # client.server.active_agents[0]["agent"].update_state() + # test_agent_state_post_message = client.server.active_agents[0]["agent"].agent_state + # print( + # f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}" + # ) if __name__ == "__main__": - test_create_preset() + # test_create_preset() test_create_agent() test_user_message() diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py index b021c8cc..6d369424 100644 --- a/tests/test_different_embedding_size.py +++ b/tests/test_different_embedding_size.py @@ -1,13 +1,13 @@ import uuid import os -from memgpt import MemGPT +from memgpt import create_client from memgpt.config import MemGPTConfig from memgpt import constants from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage from memgpt.embeddings import embedding_model from memgpt.agent_store.storage import StorageConnector, TableType -from .utils import wipe_config +from .utils import wipe_config, create_config import uuid @@ -49,14 +49,12 @@ def test_create_user(): wipe_config() # create client - client = MemGPT(quickstart="openai", user_id=test_user_id) + create_config("openai") + client = create_client() # openai: create agent openai_agent = client.create_agent( - { - "user_id": test_user_id, - "name": "openai_agent", - } + name="openai_agent", ) assert ( openai_agent.embedding_config.embedding_endpoint_type == "openai" @@ -69,16 +67,13 @@ def test_create_user(): # hosted: create agent hosted_agent = client.create_agent( - { - "user_id": test_user_id, - "name": "hosted_agent", - "embedding_config": EmbeddingConfig( - embedding_endpoint_type="hugging-face", - embedding_model="BAAI/bge-large-en-v1.5", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_dim=1024, - ), - } + name="hosted_agent", + embedding_config=EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_dim=1024, + ), ) # check to make sure endpoint overriden assert ( diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index c77f9f76..786976c3 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -13,8 +13,8 @@ from memgpt.config import MemGPTConfig from memgpt.credentials import MemGPTCredentials from memgpt.metadata import MetadataStore from memgpt.data_types import User, AgentState, EmbeddingConfig -from memgpt import MemGPT -from .utils import wipe_config +from memgpt import create_client +from .utils import wipe_config, create_config @pytest.fixture(autouse=True) @@ -38,6 +38,7 @@ def recreate_declarative_base(): @pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"]) @pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"]) def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base): + wipe_config() # setup config config = MemGPTConfig() if metadata_storage_connector == "postgres": @@ -74,7 +75,6 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c openai_key=os.getenv("OPENAI_API_KEY"), ) credentials.save() - client = MemGPT(quickstart="openai", user_id=user.id) embedding_config = EmbeddingConfig( embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", @@ -83,7 +83,6 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c ) else: - client = MemGPT(quickstart="memgpt_hosted", user_id=user.id) embedding_config = EmbeddingConfig( embedding_endpoint_type="local", embedding_endpoint=None, diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 89d10ef5..fa9dec66 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,6 +1,5 @@ import os from memgpt.migrate import migrate_all_agents, migrate_all_sources -from memgpt import MemGPT from memgpt.config import MemGPTConfig from .utils import wipe_config from memgpt.server.server import SyncServer diff --git a/tests/test_openai_assistant_api.py b/tests/test_openai_assistant_api.py index 6000df20..4fc02bf9 100644 --- a/tests/test_openai_assistant_api.py +++ b/tests/test_openai_assistant_api.py @@ -7,49 +7,49 @@ from memgpt.server.rest_api.server import app from memgpt.constants import DEFAULT_PRESET from memgpt.config import MemGPTConfig - -def test_list_messages(): - client = TestClient(app) - - test_user_id = uuid.UUID(MemGPTConfig.load().anon_clientid) - - # create user - server = SyncServer() - if not server.get_user(test_user_id): - server.create_user({"id": test_user_id}) - - # write default presets to DB - server.initialize_default_presets(test_user_id) - - # test: create agent - request_body = { - "user_id": str(test_user_id), - "assistant_name": DEFAULT_PRESET, - } - print(request_body) - response = client.post("/v1/threads", json=request_body) - assert response.status_code == 200, f"Error: {response.json()}" - agent_id = response.json()["id"] - print(response.json()) - - # test: insert messages - # TODO: eventually implement the "run" functionality - request_body = { - "user_id": str(test_user_id), - "content": "Hello, world!", - "role": "user", - } - response = client.post(f"/v1/threads/{str(agent_id)}/messages", json=request_body) - assert response.status_code == 200, f"Error: {response.json()}" - - # test: list messages - thread_id = str(agent_id) - params = { - "limit": 10, - "order": "desc", - # "after": "", - "user_id": str(test_user_id), - } - response = client.get(f"/v1/threads/{thread_id}/messages", params=params) - assert response.status_code == 200, f"Error: {response.json()}" - print(response.json()) +# TODO: modify this to run against an actual running server +# def test_list_messages(): +# client = TestClient(app) +# +# test_user_id = uuid.UUID(MemGPTConfig.load().anon_clientid) +# +# # create user +# server = SyncServer() +# if not server.get_user(test_user_id): +# print("Creating user in test_list_messages", test_user_id) +# server.create_user({"id": test_user_id}) +# else: +# print("User already exists in test_list_messages", test_user_id) +# +# # write default presets to DB +# server.initialize_default_presets(test_user_id) +# +# # test: create agent +# request_body = { +# "assistant_name": DEFAULT_PRESET, +# } +# print(request_body) +# response = client.post("/v1/threads", json=request_body) +# assert response.status_code == 200, f"Error: {response.json()}" +# agent_id = response.json()["id"] +# print(response.json()) +# +# # test: insert messages +# # TODO: eventually implement the "run" functionality +# request_body = { +# "content": "Hello, world!", +# "role": "user", +# } +# response = client.post(f"/v1/threads/{str(agent_id)}/messages", json=request_body) +# assert response.status_code == 200, f"Error: {response.json()}" +# +# # test: list messages +# thread_id = str(agent_id) +# params = { +# "limit": 10, +# "order": "desc", +# } +# response = client.get(f"/v1/threads/{thread_id}/messages", params=params) +# assert response.status_code == 200, f"Error: {response.json()}" +# print(response.json()) +# diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 00000000..0d1126ed --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,52 @@ +# test state saving between client session +# TODO: update this test with correct imports + + +# def test_save_load(client): +# """Test that state is being persisted correctly after an /exit +# +# Create a new agent, and request a message +# +# Then trigger +# """ +# assert client is not None, "Run create_agent test first" +# assert test_agent_state is not None, "Run create_agent test first" +# assert test_agent_state_post_message is not None, "Run test_user_message test first" +# +# # Create a new client (not thread safe), and load the same agent +# # The agent state inside should correspond to the initial state pre-message +# if os.getenv("OPENAI_API_KEY"): +# client2 = MemGPT(quickstart="openai", user_id=test_user_id) +# else: +# client2 = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) +# print(f"\n\n[3] CREATING CLIENT2, LOADING AGENT {test_agent_state.id}!") +# client2_agent_obj = client2.server._get_or_load_agent(user_id=test_user_id, agent_id=test_agent_state.id) +# client2_agent_state = client2_agent_obj.update_state() +# print(f"[3] LOADED AGENT! AGENT {client2_agent_state.id}\n\tmessages={client2_agent_state.state['messages']}") +# +# # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}" +# def check_state_equivalence(state_1, state_2): +# """Helper function that checks the equivalence of two AgentState objects""" +# assert state_1.keys() == state_2.keys(), f"{state_1.keys()}\n{state_2.keys}" +# for k, v1 in state_1.items(): +# v2 = state_2[k] +# if isinstance(v1, LLMConfig) or isinstance(v1, EmbeddingConfig): +# assert vars(v1) == vars(v2), f"{vars(v1)}\n{vars(v2)}" +# else: +# assert v1 == v2, f"{v1}\n{v2}" +# +# check_state_equivalence(vars(test_agent_state), vars(client2_agent_state)) +# +# # Now, write out the save from the original client +# # This should persist the test message into the agent state +# client.save() +# +# if os.getenv("OPENAI_API_KEY"): +# client3 = MemGPT(quickstart="openai", user_id=test_user_id) +# else: +# client3 = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) +# client3_agent_obj = client3.server._get_or_load_agent(user_id=test_user_id, agent_id=test_agent_state.id) +# client3_agent_state = client3_agent_obj.update_state() +# +# check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state)) +# diff --git a/tests/test_server.py b/tests/test_server.py index a855fc1f..b140ddb7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -15,6 +15,7 @@ from .utils import wipe_config, wipe_memgpt_home def test_server(): load_dotenv() + wipe_config() wipe_memgpt_home() # Use os.getenv with a fallback to os.environ.get @@ -93,7 +94,10 @@ def test_server(): # create agent agent_state = server.create_agent( user_id=user.id, - agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"), + name="test_agent", + preset="memgpt_chat", + human="cs_phd", + persona="sam_pov", ) print(f"Created agent\n{agent_state}") diff --git a/tests/test_summarize.py b/tests/test_summarize.py index e4b5a03a..82f617bc 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -1,16 +1,15 @@ import os import uuid -from memgpt import MemGPT +from memgpt import create_client from memgpt.config import MemGPTConfig from memgpt import constants import memgpt.functions.function_sets.base as base_functions -from .utils import wipe_config +from .utils import wipe_config, create_config # test_agent_id = "test_agent" test_agent_name = f"test_client_{str(uuid.uuid4())}" -test_user_id = uuid.uuid4() client = None agent_obj = None @@ -20,22 +19,20 @@ def create_test_agent(): wipe_config() global client if os.getenv("OPENAI_API_KEY"): - client = MemGPT(quickstart="openai", user_id=test_user_id) + create_config("openai") else: - client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) + create_config("memgpt_hosted") + client = create_client() agent_state = client.create_agent( - agent_config={ - "user_id": test_user_id, - "name": test_agent_name, - "persona": constants.DEFAULT_PERSONA, - "human": constants.DEFAULT_HUMAN, - } + name=test_agent_name, + persona=constants.DEFAULT_PERSONA, + human=constants.DEFAULT_HUMAN, ) global agent_obj config = MemGPTConfig.load() - agent_obj = client.server._get_or_load_agent(user_id=test_user_id, agent_id=agent_state.id) + agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id) def test_summarize(): diff --git a/tests/utils.py b/tests/utils.py index 1de59779..5225902f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,10 +2,22 @@ import datetime import os from memgpt.config import MemGPTConfig +from memgpt.cli.cli import quickstart, QuickstartChoice +from memgpt import Admin from .constants import TIMEOUT +def create_config(endpoint="openai"): + """Create config file matching quickstart option""" + if endpoint == "openai": + quickstart(QuickstartChoice.openai) + elif endpoint == "memgpt_hosted": + quickstart(QuickstartChoice.memgpt_hosted) + else: + raise ValueError(f"Invalid endpoint {endpoint}") + + def wipe_config(): if MemGPTConfig.exists(): # delete