From ef913a743bc3b7731552cd7a0dd96ed153b80dd1 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 27 Mar 2024 17:09:11 -0700 Subject: [PATCH] feat: move quickstart to use inference.memgpt.ai (#1190) --- configs/memgpt_hosted.json | 6 +-- memgpt/cli/cli.py | 69 +++++++++++++++----------- memgpt/client/client.py | 4 -- memgpt/config.py | 5 +- memgpt/configs/memgpt_hosted.json | 6 +-- memgpt/configs/openai.json | 4 +- memgpt/data_types.py | 14 +++--- memgpt/llm_api_tools.py | 3 +- memgpt/server/server.py | 3 +- tests/test_base_functions.py | 23 +++------ tests/test_client.py | 6 ++- tests/test_different_embedding_size.py | 10 ++-- tests/test_load_archival.py | 17 ++++++- tests/test_server.py | 1 + tests/test_storage.py | 14 +++++- 15 files changed, 106 insertions(+), 79 deletions(-) diff --git a/configs/memgpt_hosted.json b/configs/memgpt_hosted.json index fb575c10..ea8030ac 100644 --- a/configs/memgpt_hosted.json +++ b/configs/memgpt_hosted.json @@ -1,9 +1,7 @@ { "context_window": 16384, - "model": "ehartford/dolphin-2.5-mixtral-8x7b", - "model_endpoint_type": "vllm", - "model_endpoint": "https://api.memgpt.ai", - "model_wrapper": "chatml", + "model_endpoint_type": "openai", + "model_endpoint": "https://inference.memgpt.ai", "embedding_endpoint_type": "hugging-face", "embedding_endpoint": "https://embeddings.memgpt.ai", "embedding_model": "BAAI/bge-large-en-v1.5", diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index f6172cc4..fe1361c5 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -75,27 +75,46 @@ def set_config_with_dict(new_config: dict) -> (MemGPTConfig, bool): printd(f"Skipping new config {k}: {v} == {new_config[k]}") # update embedding config - for k, v in vars(old_config.default_embedding_config).items(): - if k in new_config: - if v != new_config[k]: - printd(f"Replacing config {k}: {v} -> {new_config[k]}") - modified = True - # old_config[k] = new_config[k] - setattr(old_config.default_embedding_config, k, new_config[k]) - else: - printd(f"Skipping new config {k}: {v} == {new_config[k]}") + if old_config.default_embedding_config: + for k, v in vars(old_config.default_embedding_config).items(): + if k in new_config: + if v != new_config[k]: + printd(f"Replacing config {k}: {v} -> {new_config[k]}") + modified = True + # old_config[k] = new_config[k] + setattr(old_config.default_embedding_config, k, new_config[k]) + else: + printd(f"Skipping new config {k}: {v} == {new_config[k]}") + else: + modified = True + fields = ["embedding_model", "embedding_dim", "embedding_chunk_size", "embedding_endpoint", "embedding_endpoint_type"] + args = {} + for field in fields: + if field in new_config: + args[field] = new_config[field] + printd(f"Setting new config {field}: {new_config[field]}") + old_config.default_embedding_config = EmbeddingConfig(**args) # update llm config - for k, v in vars(old_config.default_llm_config).items(): - if k in new_config: - if v != new_config[k]: - printd(f"Replacing config {k}: {v} -> {new_config[k]}") - modified = True - # old_config[k] = new_config[k] - setattr(old_config.default_llm_config, k, new_config[k]) - else: - printd(f"Skipping new config {k}: {v} == {new_config[k]}") - + if old_config.default_llm_config: + for k, v in vars(old_config.default_llm_config).items(): + if k in new_config: + if v != new_config[k]: + printd(f"Replacing config {k}: {v} -> {new_config[k]}") + modified = True + # old_config[k] = new_config[k] + setattr(old_config.default_llm_config, k, new_config[k]) + else: + printd(f"Skipping new config {k}: {v} == {new_config[k]}") + else: + modified = True + fields = ["model", "model_endpoint", "model_endpoint_type", "model_wrapper", "context_window"] + args = {} + for field in fields: + if field in new_config: + args[field] = new_config[field] + printd(f"Setting new config {field}: {new_config[field]}") + old_config.default_llm_config = LLMConfig(**args) return (old_config, modified) @@ -153,10 +172,13 @@ def quickstart( else: # Load the file from the relative path script_dir = os.path.dirname(__file__) # Get the directory where the script is located + print("SCRIPT", script_dir) backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json") + print("FILE PATH", backup_config_path) try: with open(backup_config_path, "r", encoding="utf-8") as file: backup_config = json.load(file) + print(backup_config) printd("Loaded config file successfully.") new_config, config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: @@ -301,15 +323,6 @@ def server( # # Add the handler to the logger # server_logger.addHandler(stream_handler) - # override config with postgres enviornment (messy, but necessary for docker compose) - if os.getenv("POSTGRES_URI"): - config = MemGPTConfig.load() - config.archival_storage_uri = os.getenv("POSTGRES_URI") - config.recall_storage_uri = os.getenv("POSTGRES_URI") - config.metadata_storage_uri = os.getenv("POSTGRES_URI") - print(f"Overriding DB config URI with enviornment variable: {config.archival_storage_uri}") - config.save() - if type == ServerChoice.rest_api: import uvicorn from memgpt.server.rest_api.server import app diff --git a/memgpt/client/client.py b/memgpt/client/client.py index ee7507a9..08ecf1db 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -551,8 +551,6 @@ class LocalClient(AbstractClient): preset: Optional[str] = None, persona: Optional[str] = None, human: Optional[str] = None, - embedding_config: Optional[EmbeddingConfig] = None, - llm_config: Optional[LLMConfig] = None, ) -> AgentState: if name and self.agent_exists(agent_name=name): raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})") @@ -564,8 +562,6 @@ class LocalClient(AbstractClient): preset=preset, persona=persona, human=human, - embedding_config=embedding_config, - llm_config=llm_config, ) return agent_state diff --git a/memgpt/config.py b/memgpt/config.py index 0f0bc410..1aff760d 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -44,10 +44,10 @@ class MemGPTConfig: human: str = DEFAULT_HUMAN # model parameters - default_llm_config: LLMConfig = field(default_factory=LLMConfig) + default_llm_config: LLMConfig = None # embedding parameters - default_embedding_config: EmbeddingConfig = field(default_factory=EmbeddingConfig) + default_embedding_config: EmbeddingConfig = None # database configs: archival archival_storage_type: str = "chroma" # local, db @@ -110,6 +110,7 @@ class MemGPTConfig: # insure all configuration directories exist cls.create_config_dir() + print(f"Loading config from {config_path}") if os.path.exists(config_path): # read existing config config.read(config_path) diff --git a/memgpt/configs/memgpt_hosted.json b/memgpt/configs/memgpt_hosted.json index fb575c10..013ffe8d 100644 --- a/memgpt/configs/memgpt_hosted.json +++ b/memgpt/configs/memgpt_hosted.json @@ -1,8 +1,8 @@ { "context_window": 16384, - "model": "ehartford/dolphin-2.5-mixtral-8x7b", - "model_endpoint_type": "vllm", - "model_endpoint": "https://api.memgpt.ai", + "model": "memgpt", + "model_endpoint_type": "openai", + "model_endpoint": "https://inference.memgpt.ai", "model_wrapper": "chatml", "embedding_endpoint_type": "hugging-face", "embedding_endpoint": "https://embeddings.memgpt.ai", diff --git a/memgpt/configs/openai.json b/memgpt/configs/openai.json index 7c76b101..82ed0d72 100644 --- a/memgpt/configs/openai.json +++ b/memgpt/configs/openai.json @@ -6,7 +6,7 @@ "model_wrapper": null, "embedding_endpoint_type": "openai", "embedding_endpoint": "https://api.openai.com/v1", - "embedding_model": null, + "embedding_model": "text-embedding-ada-002", "embedding_dim": 1536, "embedding_chunk_size": 300 -} \ No newline at end of file +} diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 2e09588d..d4056466 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -370,9 +370,9 @@ class Passage(Record): class LLMConfig: def __init__( self, - model: Optional[str] = "gpt-4", - model_endpoint_type: Optional[str] = "openai", - model_endpoint: Optional[str] = "https://api.openai.com/v1", + model: Optional[str] = None, + model_endpoint_type: Optional[str] = None, + model_endpoint: Optional[str] = None, model_wrapper: Optional[str] = None, context_window: Optional[int] = None, ): @@ -391,10 +391,10 @@ class LLMConfig: class EmbeddingConfig: def __init__( self, - embedding_endpoint_type: Optional[str] = "openai", - embedding_endpoint: Optional[str] = "https://api.openai.com/v1", - embedding_model: Optional[str] = "text-embedding-ada-002", - embedding_dim: Optional[int] = 1536, + embedding_endpoint_type: Optional[str] = None, + embedding_endpoint: Optional[str] = None, + embedding_model: Optional[str] = None, + embedding_dim: Optional[int] = None, embedding_chunk_size: Optional[int] = 300, ): self.embedding_endpoint_type = embedding_endpoint_type diff --git a/memgpt/llm_api_tools.py b/memgpt/llm_api_tools.py index dbfe81dc..b748aeed 100644 --- a/memgpt/llm_api_tools.py +++ b/memgpt/llm_api_tools.py @@ -416,7 +416,8 @@ def create( # openai if agent_state.llm_config.model_endpoint_type == "openai": # TODO do the same for Azure? - if credentials.openai_key is None: + if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1": + # only is a problem if we are *not* using an openai proxy raise ValueError(f"OpenAI key is missing from MemGPT config file") if use_tool_naming: data = dict( diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 183d578e..f1e0fd22 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -240,7 +240,8 @@ class SyncServer(LockingServer): embedding_endpoint_type=self.config.default_embedding_config.embedding_endpoint_type, embedding_endpoint=self.config.default_embedding_config.embedding_endpoint, embedding_dim=self.config.default_embedding_config.embedding_dim, - # openai_key=self.credentials.openai_key, + embedding_model=self.config.default_embedding_config.embedding_model, + embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size, ) # Initialize the metadata store diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 6b655796..accf9ede 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -1,4 +1,5 @@ import os +import pytest import uuid from memgpt import create_client @@ -10,10 +11,10 @@ from .utils import wipe_config, create_config # test_agent_id = "test_agent" client = None -agent_obj = None -def create_test_agent(): +@pytest.fixture(scope="module") +def agent_obj(): """Create a test agent that we can call functions on""" wipe_config() global client @@ -31,26 +32,18 @@ def create_test_agent(): global agent_obj user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid) agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id) + yield agent_obj + + client.delete_agent(agent_obj.agent_state.id) -def test_archival(): - global agent_obj - if agent_obj is None: - create_test_agent() - assert agent_obj is not None - +def test_archival(agent_obj): base_functions.archival_memory_insert(agent_obj, "banana") - base_functions.archival_memory_search(agent_obj, "banana") base_functions.archival_memory_search(agent_obj, "banana", page=0) -def test_recall(): - global agent_obj - if agent_obj is None: - create_test_agent() - +def test_recall(agent_obj): base_functions.conversation_search(agent_obj, "banana") base_functions.conversation_search(agent_obj, "banana", page=0) - base_functions.conversation_search_date(agent_obj, start_date="2022-01-01", end_date="2022-01-02") diff --git a/tests/test_client.py b/tests/test_client.py index 4b87da30..41f35146 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -59,6 +59,7 @@ def run_server(): embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, + embedding_model="text-embedding-ada-002", ), # llms default_llm_config=LLMConfig( @@ -237,9 +238,10 @@ def test_config(client, agent): models_response = client.list_models() print("MODELS", models_response) - config_response = client.get_config() + # TODO: add back + # config_response = client.get_config() # TODO: ensure config is the same as the one in the server - print("CONFIG", config_response) + # print("CONFIG", config_response) def test_sources(client, agent): diff --git a/tests/test_different_embedding_size.py b/tests/test_different_embedding_size.py index ffeb3624..0fdde27d 100644 --- a/tests/test_different_embedding_size.py +++ b/tests/test_different_embedding_size.py @@ -68,15 +68,13 @@ def test_create_user(): openai_agent_run = client.server._get_or_load_agent(user_id=client.user.id, agent_id=openai_agent.id) openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages) + # create client + create_config("memgpt_hosted") + client = create_client() + # hosted: create agent hosted_agent = client.create_agent( name="hosted_agent", - embedding_config=EmbeddingConfig( - embedding_endpoint_type="hugging-face", - embedding_model="BAAI/bge-large-en-v1.5", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_dim=1024, - ), ) # check to make sure endpoint overriden assert ( diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 7fa56cdc..09bc29dc 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -11,10 +11,10 @@ from memgpt.cli.cli_load import load_directory # from memgpt.data_sources.connectors import DirectoryConnector, load_data from memgpt.credentials import MemGPTCredentials from memgpt.metadata import MetadataStore -from memgpt.data_types import User, AgentState, EmbeddingConfig +from memgpt.data_types import User, AgentState, EmbeddingConfig, LLMConfig from memgpt.utils import get_human_text, get_persona_text from tests import TEST_MEMGPT_CONFIG -from .utils import wipe_config +from .utils import wipe_config, create_config @pytest.fixture(autouse=True) @@ -44,6 +44,18 @@ def test_load_directory( recreate_declarative_base, ): wipe_config() + TEST_MEMGPT_CONFIG.default_embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_model="text-embedding-ada-002", + ) + TEST_MEMGPT_CONFIG.default_llm_config = LLMConfig( + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ) + # setup config if metadata_storage_connector == "postgres": if not os.getenv("MEMGPT_PGURI"): @@ -84,6 +96,7 @@ def test_load_directory( embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, + embedding_model="text-embedding-ada-002", # openai_key=os.getenv("OPENAI_API_KEY"), ) diff --git a/tests/test_server.py b/tests/test_server.py index 4704849d..cc8820b7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -35,6 +35,7 @@ def server(): default_embedding_config=EmbeddingConfig( embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-ada-002", embedding_dim=1536, ), # llms diff --git a/tests/test_storage.py b/tests/test_storage.py index d592792c..9450f26c 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -5,7 +5,7 @@ import pytest from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.embeddings import embedding_model, query_embedding -from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState +from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, LLMConfig from memgpt.credentials import MemGPTCredentials from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.metadata import MetadataStore @@ -120,7 +120,17 @@ def test_storage( # if 'Message' in globals(): # print("Removing messages", globals()['Message']) # del globals()['Message'] - + TEST_MEMGPT_CONFIG.default_embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_model="text-embedding-ada-002", + ) + TEST_MEMGPT_CONFIG.default_llm_config = LLMConfig( + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ) if storage_connector == "postgres": if not os.getenv("MEMGPT_PGURI"): print("Skipping test, missing PG URI")